Say I have a tensor named `attn_weights`, of size `[1, a]`, whose entries indicate the attention weights between the given query and `a` keys. I want to select the largest one using `torch.nn.functional.gumbel_softmax`.
I find the docs describe this function's parameter as *logits – [..., num_features] unnormalized log probabilities*. I wonder whether I should take the log of `attn_weights` before passing it into `gumbel_softmax`. I also find that Wikipedia defines `logit = log(p / (1 - p))`, which is different from a plain logarithm. Which one should I pass to the function?
Further, I wonder how to choose `tau` in `gumbel_softmax`. Are there any guidelines?
If `attn_weights` are probabilities (sum to 1; e.g., the output of a softmax), then yes, take the log: for a categorical distribution, `log(p)` values are valid unnormalized log-probabilities, so the binary `logit = log(p / (1 - p))` formula from Wikipedia does not apply here. Otherwise, no.

As for `tau`, it usually requires tuning. The references provided in the docs can help you with that.
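For concreteness, here is a minimal sketch of that usage (the shape, `tau` value, and epsilon are placeholders, not prescriptions):

```python
import torch
import torch.nn.functional as F

# Hypothetical attention weights: probabilities over a = 5 keys (rows sum to 1).
attn_weights = torch.softmax(torch.randn(1, 5), dim=-1)

# gumbel_softmax expects unnormalized log-probabilities, so take the log.
# The small epsilon guards against log(0) if a weight can be exactly zero.
logits = torch.log(attn_weights + 1e-10)

# hard=True returns a one-hot sample (via the straight-through estimator),
# which stochastically "selects" one key while remaining differentiable.
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(sample)  # e.g. tensor([[0., 0., 1., 0., 0.]])
```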
From Categorical Reparameterization with Gumbel-Softmax:
Figure 1, caption:

> ... Samples from Gumbel-Softmax distributions are identical to samples from a categorical distribution as τ → 0. At higher temperatures, Gumbel-Softmax samples are no longer one-hot, and become uniform as τ → ∞.
Section 2.2, 2nd paragraph (emphasis mine):

> While Gumbel-Softmax samples are differentiable, they are not identical to samples from the corresponding categorical distribution for non-zero temperature. For learning, there is a tradeoff between small temperatures, where samples are close to one-hot but the variance of the gradients is large, and large temperatures, where samples are smooth but the variance of the gradients is small. **In practice, we start at a high temperature and anneal to a small but non-zero temperature.**
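One common way to implement that advice is an exponential annealing schedule; the constants below are illustrative, not taken from the paper:

```python
import math
import torch
import torch.nn.functional as F

tau_start, tau_min, anneal_rate = 5.0, 0.5, 1e-4  # illustrative values
logits = torch.log(torch.softmax(torch.randn(1, 5), dim=-1) + 1e-10)

for step in range(10_000):
    # Decay tau exponentially with the training step, clamped from below.
    tau = max(tau_min, tau_start * math.exp(-anneal_rate * step))
    sample = F.gumbel_softmax(logits, tau=tau, hard=True)
    # ... use `sample` in the forward pass and backpropagate as usual ...
```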
Lastly, they remind the reader that `tau` can be learned:

> If τ is a learned parameter (rather than annealed via a fixed schedule), this scheme can be interpreted as entropy regularization, where the Gumbel-Softmax distribution can adaptively adjust the "confidence" of proposed samples during the training process.
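A minimal sketch of a learned temperature, assuming you sample the Gumbel noise by hand (the module name is hypothetical; parameterizing `log(tau)` keeps `tau` strictly positive):

```python
import torch

class GumbelSoftmaxLearnedTau(torch.nn.Module):
    """Gumbel-Softmax with a learnable temperature (soft samples only)."""

    def __init__(self, init_tau: float = 1.0):
        super().__init__()
        # Learn log(tau) so that tau = exp(log_tau) stays strictly positive.
        self.log_tau = torch.nn.Parameter(torch.tensor(init_tau).log())

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Gumbel(0, 1) noise: -log(E) with E ~ Exponential(1).
        gumbel_noise = -torch.empty_like(logits).exponential_().log()
        tau = self.log_tau.exp()
        return torch.softmax((logits + gumbel_noise) / tau, dim=-1)

sampler = GumbelSoftmaxLearnedTau(init_tau=2.0)
logits = torch.log(torch.softmax(torch.randn(1, 5), dim=-1) + 1e-10)
soft_sample = sampler(logits)  # gradients flow to both logits and tau
```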