I'm implementing the forward pass for a simple maxpool2d function (kernel_size is always square, and stride is always equal to kernel_size), and I'm having trouble getting the correct indices for the max values. The indices my function returns are different from the ones PyTorch's maxpool2d generates.

Say the input is:

tensor([[[[1.00, 2.00, 5.00, 6.00], [3.00, 4.00, 7.00, 8.00]]]])

My implementation transforms the tensor so that each kernel-sized block is in its own sub-array:

tensor([[[[1.00, 2.00, 3.00, 4.00], [5.00, 6.00, 7.00, 8.00]]]])

and then returns the index of the max within each block:

[[[[3, 3]]]]

That's how I thought it should be implemented, but the PyTorch implementation calculates the indices based on the original input tensor, not the reshaped one. So in this case the indices should be:

[[[[5, 7]]]]

because 4.00 is the 6th element of the flattened original input (zero-based index 5) and 8.00 is the 8th (index 7). Here is my current code:
def my_max_pool2d(input, kernel_size):
    batch_size, channels, height, width = input.size()
    out_height = height // kernel_size
    out_width = width // kernel_size
    # unfold extracts non-overlapping kernel_size x kernel_size patches
    input_reshaped = input.unfold(2, kernel_size, kernel_size).unfold(3, kernel_size, kernel_size)
    # flatten each patch into its own trailing dimension
    input_reshaped = input_reshaped.contiguous().view(batch_size, channels, out_height, out_width, -1)  # [[[[1.00, 2.00, 3.00, 4.00], [5.00, 6.00, 7.00, 8.00]]]]
    output, indices = input_reshaped.max(dim=-1)  # indices are per-block, e.g. [[[[3, 3]]]]
    return output, indices
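For reference, here's a minimal repro of the PyTorch behavior I'm comparing against, using torch.nn.functional.max_pool2d with return_indices=True on the example input above:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[1., 2., 5., 6.],
                    [3., 4., 7., 8.]]]])

# PyTorch indexes into the flattened H x W plane of the original input,
# so the returned indices refer to positions in the 2x4 layout.
out, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)
print(out)  # tensor([[[[4., 8.]]]])
print(idx)  # tensor([[[[5, 7]]]])
```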
How can I convert the indices so they are based on the original input tensor? I went to the trouble of unfolding and reshaping the input based on a PyTorch forum thread about a custom maxpool function like mine. Why is that needed if I'm supposed to get the indices from the original input tensor anyway? Just to get the max values? I've tried taking just the max values from my reshaped tensor and then looking up the index of those values in the original input, but that only works when there are no duplicate values in the input tensor, which happens quite often.
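To make the duplicate problem concrete, here's a small sketch (with values I made up) showing why matching on max values alone is ambiguous:

```python
import torch

# 4.0 is the max of BOTH 2x2 blocks here, so searching the original
# tensor for the value 4.0 can't tell me which occurrence belongs
# to which block.
x = torch.tensor([[[[4., 1., 2., 3.],
                    [1., 1., 4., 1.]]]])

flat = x.view(-1)
matches = (flat == 4.0).nonzero(as_tuple=True)[0]
print(matches)  # both positions where 4.0 occurs
```

Naively taking the first match would give the same index for both blocks, even though the second block's max sits at a different position in the original layout.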