How to do PyTorch F.cross_entropy?


I have a model output of dimension batch_size x 14 x 100 (14 objects x 100 classes). I want to compute the cross-entropy loss against ground-truth class indices of dimension batch_size x 14. However, when I call torch.nn.functional.cross_entropy, I get an error that says Expected target size [15, 100], got [15, 14]. Does anyone know what the reason is? Thank you in advance.

Best answer:

See the cross_entropy documentation:

For higher-dimensional inputs, the input and target are expected to have sizes (N, C, d_1, ..., d_k) and (N, d_1, ..., d_k) respectively, where N is the batch size and C is the number of classes.

Your output should therefore have shape batch_size x 100 x 14 rather than batch_size x 14 x 100, i.e. the class dimension must come second.
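
As a minimal sketch (assuming logits of shape (batch_size, 14, 100) and integer targets of shape (batch_size, 14); the tensor names here are placeholders, not from your code), you can swap the last two dimensions before calling F.cross_entropy:

```python
import torch
import torch.nn.functional as F

batch_size = 15
num_objects, num_classes = 14, 100

# Stand-ins for the model output and the ground-truth indices.
logits = torch.randn(batch_size, num_objects, num_classes)        # (N, 14, 100)
targets = torch.randint(num_classes, (batch_size, num_objects))   # (N, 14)

# F.cross_entropy expects the class dimension second: (N, C, d_1, ...).
# Move the 100-class dimension in front of the 14 object slots.
loss = F.cross_entropy(logits.permute(0, 2, 1), targets)          # logits now (N, 100, 14)
print(loss)
```

Equivalently, with the default reduction='mean' you could flatten the object dimension into the batch, e.g. F.cross_entropy(logits.reshape(-1, num_classes), targets.reshape(-1)), which averages over the same batch_size * 14 predictions.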