How to find the indexes of the first n maximum values of a tensor?

I know that torch.argmax(x, dim=0) returns the index of the first maximum value in x along dimension 0. But is there an efficient way to return the indexes of the first n maximum values? If there are duplicate values, I want their indexes to be included among the n indexes as well.

As a concrete example, say x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1]). I would like a function

generalized_argmax(x: torch.Tensor, n: int)

such that generalized_argmax(x, 4) returns [0, 2, 4, 5] in this example.

Best answer (by Daraan)

To get all of them you need to go over the whole tensor anyway, so the most efficient approach should be to use argsort and manually limit the result to the first n entries.

>>> import torch
>>> x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
>>> n = 4
>>> x.argsort(dim=0, descending=True, stable=True)[:n]  # stable=True keeps ties in original order
tensor([2, 4, 0, 5])

Sort it again to get [0, 2, 4, 5] if you need the ascending order of indices.
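
Putting the two steps together, a minimal sketch of the generalized_argmax function from the question could look like the following. It assumes x is one-dimensional and that n is no larger than the number of elements. It passes stable=True to argsort so that equal values keep their original order, which means ties at the cutoff are resolved in favour of the earlier index.

```python
import torch


def generalized_argmax(x: torch.Tensor, n: int) -> torch.Tensor:
    """Indices of the n largest values of a 1-D tensor, in ascending index order."""
    # Sort indices by value, largest first; stable=True keeps equal
    # values in their original order, so earlier indices win ties.
    top = torch.argsort(x, dim=0, descending=True, stable=True)[:n]
    # Sort the selected indices themselves to get ascending index order.
    return top.sort().values


x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
print(generalized_argmax(x, 4).tolist())  # [0, 2, 4, 5]
```

The function returns a tensor; call .tolist() on the result if a plain Python list is needed, as in the example from the question.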