Keras MeanSquaredError: calculate loss per individual sample


I'm trying to get the MeanSquaredError of each individual sample in my tensors.

Here is some sample code to show my problem.

import numpy as np
import tensorflow as tf

src = np.random.uniform(size=(2, 5, 10))
tgt = np.random.uniform(size=(2, 5, 10))
srcTF = tf.convert_to_tensor(src)
tgtTF = tf.convert_to_tensor(tgt)
print(srcTF.shape, tgtTF.shape)

lf = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

flowResults = lf(srcTF, tgtTF)
print(flowResults.shape)

Here are the results:

(2, 5, 10) (2, 5, 10)
(2, 5)

I want to keep all the original dimensions of my tensors and just calculate loss on the individual samples. Is there a way to do this in TensorFlow? Note that PyTorch's torch.nn.MSELoss(reduction='none') does exactly what I want, so is there an alternative that's more like that?
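For reference, here is a minimal sketch of the PyTorch behaviour I'm after, reusing src and tgt from above (variable names are illustrative):

import torch

ptLoss = torch.nn.MSELoss(reduction='none')
perElement = ptLoss(torch.from_numpy(src), torch.from_numpy(tgt))
print(perElement.shape)  # torch.Size([2, 5, 10]) -- one squared error per element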

Best answer:

Here is a way to do it:

mse = tf.keras.losses.MSE(tf.expand_dims(srcTF, axis=-1),
                          tf.expand_dims(tgtTF, axis=-1))
print(mse.shape)  # TensorShape([2, 5, 10])

I think the key here is what counts as a sample. MSE is computed over the last axis, so that axis is what gets "reduced": each entry in the resulting (2, 5) tensor is the mean of the 10 squared errors along the last axis. To get back the original shape, we have to make every scalar its own sample, which is what tf.expand_dims(<tensor>, -1) accomplishes. The inputs become (2, 5, 10, 1), so effectively (2, 5, 10) is the batch shape and each size-1 slice along the new axis is a single prediction; the mean over that size-1 axis is a no-op, and the result keeps the shape (2, 5, 10) with one squared error per element.
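Equivalently, since averaging over a size-1 axis changes nothing, you can skip the loss object entirely and compute the elementwise squared error directly. A sketch using tf.math.squared_difference (not part of the original answer):

perElement = tf.math.squared_difference(srcTF, tgtTF)
print(perElement.shape)  # (2, 5, 10), same values as the expand_dims approach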