I have a 2D tensor and I want to get the indices of its top k values. I know about PyTorch's topk function, but the problem is that it computes the top k values along a single dimension. I want the top k values over both dimensions, i.e. over the whole tensor.
For example, given the following tensor:
import torch

a = torch.tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])
torch.topk gives me the top 3 indices within each row:
values, indices = torch.topk(a, 3)
print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])
But I want to get the following:
tensor([[0, 1],
        [2, 0],
        [3, 1]])
These are the (row, column) indices of the three 9s, which are the top 3 values over the whole 2D tensor.
Is there a way to achieve this in PyTorch?
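One possible approach, as a minimal sketch: flatten the tensor so every element competes in a single torch.topk call, then convert each flat (row-major) index back to (row, col) using the row length. Recent PyTorch releases (2.2 and later, if I recall correctly) also provide torch.unravel_index for that last step; the version below does the arithmetic by hand so it should work on older versions too. Sorting the flat indices is an extra step I added only so tied values come back in a deterministic row-major order.

```python
import torch

# The example tensor from the question.
a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])
k = 3

# 1. Run topk over the flattened (row-major) tensor so all elements compete.
values, flat_idx = torch.topk(a.flatten(), k)

# 2. Optional: tied values may come back in any order, so sort the flat
#    indices by position. This replaces descending-value order with
#    row-major order; skip it if you want results ordered by value.
flat_idx, _ = torch.sort(flat_idx)
values = a.flatten()[flat_idx]

# 3. Convert each flat index back to (row, col) using the number of columns.
n_cols = a.size(1)
rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
cols = flat_idx % n_cols
indices = torch.stack((rows, cols), dim=1)

print(indices)
# tensor([[0, 1],
#         [2, 0],
#         [3, 1]])
```

Flattening keeps this to a single topk call; the division and modulo only recover the 2D coordinates, so the same idea extends to more dimensions by dividing repeatedly by each trailing dimension size.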
                        