I know that torch.argmax(x, dim = 0) returns the index of the first maximum value in x along dimension 0. But is there an efficient way to return the indexes of the first n maximum values? If there are duplicate values I also want the index of those among the n indexes.
As a concrete example, say x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1]). I would like a function
generalized_argmax(x: torch.Tensor, n: int)
such that
generalized_argmax(x, 4)
returns [0, 2, 4, 5] in this example.
To get all of them you have to go over the whole tensor anyway, so the most efficient approach should be to use argsort and manually limit the result to n entries.
Sort that result again to get [0, 2, 4, 5] if you need the indices in ascending order.
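A minimal sketch of that idea, assuming a 1-D tensor and reusing the function name generalized_argmax from the question (it is not a PyTorch function). It sorts in descending order, keeps the first n indices, and then sorts those indices ascending:

```python
import torch


def generalized_argmax(x: torch.Tensor, n: int) -> torch.Tensor:
    """Return the indices of the n largest values of a 1-D tensor, ascending."""
    # Indices that would sort x from largest to smallest value.
    # Note: if equal values straddle the cutoff at n, which of them
    # is kept is not guaranteed by this call.
    order = torch.argsort(x, descending=True)
    # Manually limit the result to the first n entries.
    top_n = order[:n]
    # Sort the kept indices so they come back in ascending order.
    return torch.sort(top_n).values


x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
print(generalized_argmax(x, 4).tolist())  # [0, 2, 4, 5]
```

In this example the four largest values are 4, 4, 2, 2, so there is no tie at the cutoff and the result is unambiguous.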