Is there a differentiable approximation of the argmax() function?


I want a function that outputs a vector in which the entry at the position of the maximum value of the input vector approaches 1 and all other entries approach 0. This function needs to be differentiable.

For example:

import torch
a = torch.tensor([2., 3., 4., 5., 4., 3., 2.], requires_grad=True)
b = f(a)  # f is the differentiable function I am looking for
>>> b
tensor([1e-10, 1e-10, 1e-10, 0.9999999, 1e-10, 1e-10, 1e-10], grad_fn=<...>)

Is there a practical way?
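A minimal sketch, assuming PyTorch and an input with a unique maximum. The usual relaxation is a softmax with a temperature tau, softmax(a / tau): as tau approaches 0 its output approaches the one-hot vector of argmax(a), and it is differentiable for every tau > 0. A second common variant is the straight-through trick, which returns an exact one-hot vector in the forward pass but uses the softmax gradient in the backward pass. PyTorch's `torch.nn.functional.gumbel_softmax(..., hard=True)` uses the same trick, but it also adds random Gumbel noise. The function names `soft_argmax_onehot` and `straight_through_argmax` and the tau values below are illustrative choices, not a library API.

```python
import torch
import torch.nn.functional as F


# Hypothetical helper (not a PyTorch API): softmax with temperature tau.
def soft_argmax_onehot(x: torch.Tensor, tau: float = 0.1, dim: int = -1) -> torch.Tensor:
    """Differentiable relaxation of one-hot argmax.

    As tau -> 0 the output approaches one_hot(argmax(x)), assuming a unique
    maximum. Smaller tau gives a sharper output but smaller gradients.
    """
    return F.softmax(x / tau, dim=dim)


# Hypothetical helper (not a PyTorch API): straight-through estimator.
def straight_through_argmax(x: torch.Tensor, tau: float = 1.0, dim: int = -1) -> torch.Tensor:
    """Exact one-hot values in the forward pass, softmax gradients in the backward pass."""
    soft = F.softmax(x / tau, dim=dim)
    index = soft.argmax(dim=dim, keepdim=True)
    # Build the hard one-hot vector (no gradient flows through it).
    hard = torch.zeros_like(soft).scatter_(dim, index, 1.0)
    # The value equals `hard`, but the gradient w.r.t. x equals that of `soft`,
    # because `hard - soft.detach()` is treated as a constant by autograd.
    return hard - soft.detach() + soft


a = torch.tensor([2., 3., 4., 5., 4., 3., 2.], requires_grad=True)

b = soft_argmax_onehot(a, tau=0.1)
print(b)

c = straight_through_argmax(a)
print(c)
```

With tau=0.1, the entry of `b` at index 3 is about 0.9999. Lowering tau pushes the other entries closer to 0, but it also shrinks the gradients toward 0, so tau trades sharpness against trainability. The straight-through version gives exact 0/1 outputs, but its gradient is only an approximation of the (zero almost everywhere) gradient of the true argmax.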
