I want to apply a filter to a tensor and remove values that do not meet my criteria. For example, let's say I have a tensor that looks like this:
softmax_tensor = [[ 0.05 , 0.05, 0.2, 0.7], [ 0.25 , 0.25, 0.3, 0.2 ]]
Right now, the classifier picks the argmax of the tensors to predict:
predictions = [[3],[2]]
But this isn't exactly what I want because I lose information about the confidence of that prediction. I would rather not make a prediction than make an incorrect one. So what I would like to do is return filtered tensors like so:
new_softmax_tensor = [[ 0.05 , 0.05, 0.2, 0.7]]
new_predictions = [[3]]
If this were straight-up Python, I'd have no trouble:
new_softmax_tensor = []
new_predictions = []
for idx, listItem in enumerate(softmax_tensor):
    # get the two highest values and see if they are far enough apart
    M = max(listItem)
    M2 = max(n for n in listItem if n != M)
    if M - M2 > 0.3:  # just making up a criterion here
        new_softmax_tensor.append(listItem)
        new_predictions.append(predictions[idx])
But given that TensorFlow works on tensors, I'm not sure how to do this. And if I did, would it break the computation graph?
A previous SO post suggested using tf.gather_nd, but in that scenario they already had a tensor that they wanted to filter on. I've also looked at tf.cond but still don't understand it. I would imagine many other people would benefit from this exact same solution.
Thanks all.
Ok. I've got it sorted out now. Here is a working example.
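The poster's own example is not reproduced here. What follows is only a minimal sketch of one way the filtering could be written, assuming TensorFlow 2.x with eager execution. The function name `filter_confident` and the `min_margin` parameter are made up for illustration. It takes the two largest values per row with `tf.math.top_k`, builds a boolean mask from their difference, and applies it with `tf.boolean_mask`.

```python
import tensorflow as tf


def filter_confident(softmax_tensor, min_margin=0.3):
    """Keep rows whose top-1 and top-2 values differ by more than min_margin.

    Hypothetical helper for illustration; not taken from the original post.
    """
    # top_k returns the k largest values per row, sorted in descending order.
    top2 = tf.math.top_k(softmax_tensor, k=2).values
    margin = top2[:, 0] - top2[:, 1]

    # Boolean mask of shape [batch]: True where the prediction is confident.
    keep = margin > min_margin

    # Drop the rows that fail the criterion.
    new_softmax = tf.boolean_mask(softmax_tensor, keep)

    # argmax over the class axis, reshaped to [n_kept, 1] as in the question.
    new_predictions = tf.expand_dims(tf.argmax(new_softmax, axis=1), axis=-1)
    return new_softmax, new_predictions


softmax_tensor = tf.constant([[0.05, 0.05, 0.2, 0.7],
                              [0.25, 0.25, 0.3, 0.2]])
new_softmax_tensor, new_predictions = filter_confident(softmax_tensor)
```

With the tensor from the question, only the first row survives (0.7 - 0.2 exceeds the margin, 0.3 - 0.25 does not). Everything here is an ordinary TensorFlow op, so it should remain part of the computation graph. The main consequence is that the first dimension of the outputs depends on the data. One difference from the Python loop: if the largest value appears twice in a row, `top_k` reports a margin of 0, whereas the loop skips the duplicates when picking the second-highest value.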