How do you decode one-hot labels in TensorFlow?


Been looking, but can't seem to find any examples of how to decode or convert back to a single integer from a one-hot value in TensorFlow.

I used tf.one_hot and was able to train my model, but I'm a bit confused about how to make sense of the label after classification. My data is fed in via a TFRecords file that I created. I thought about storing a text label in the file but couldn't get it to work; it seemed as if TFRecords couldn't store text strings, but maybe I was mistaken.

Best answer:

You can find the index of the largest element in a tensor using tf.argmax. Since your one-hot vector is one-dimensional, with a single 1 and all other entries 0, that index is the class label. For a single vector:

index = tf.argmax(one_hot_vector, axis=0)

For the more common matrix of shape [batch_size, num_classes], use axis=1 to get a result of shape [batch_size], one integer label per row.
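
A minimal round-trip sketch, assuming TensorFlow 2.x with eager execution: it encodes labels with tf.one_hot (as in the question), then decodes a single vector with axis=0 and a whole batch with axis=1. The labels and num_classes values are made up for illustration.

```python
import tensorflow as tf

# Illustrative integer labels and class count (not from the question).
labels = tf.constant([2, 0, 3, 1])
num_classes = 4

# Encode: shape [batch_size, num_classes] = [4, 4], dtype float32.
one_hot = tf.one_hot(labels, depth=num_classes)

# Single vector: reduce over axis 0, giving a scalar index.
single_index = tf.argmax(one_hot[0], axis=0)

# Batch: reduce over the class axis (axis=1, or equivalently axis=-1).
# The result is 1-D with shape [batch_size], not [batch_size, 1].
batch_indices = tf.argmax(one_hot, axis=1)

print(single_index.numpy())   # 2
print(batch_indices.numpy())  # [2 0 3 1]
print(batch_indices.shape)    # (4,)
```

The same tf.argmax call works on a model's softmax output, where it returns the index of the highest-scoring class in each row.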

Answer:

Since a one-hot encoding is typically just a matrix with batch_size rows and num_classes columns, where each row is all zeros except for a single non-zero entry at the chosen class, you can use tf.argmax() to recover a vector of integer labels:

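import tensorflow as tf

BATCH_SIZE = 3
NUM_CLASSES = 4
# Shape [BATCH_SIZE, NUM_CLASSES]: one row per example.
one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])

# Compute the argmax across the columns (the class axis).
decoded = tf.argmax(one_hot_encoded, axis=1)

# TensorFlow 2.x runs eagerly, so no Session is needed.
# (In TensorFlow 1.x this was: print(sess.run(decoded)).)
print(decoded.numpy())  # ==> [1 0 3]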
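
To turn decoded indices back into readable labels, one option is to keep a list of class names outside the TFRecords file and look them up with tf.gather. This is a sketch assuming TensorFlow 2.x; the class_names values are hypothetical and must follow the same index order used when the labels were encoded.

```python
import tensorflow as tf

# Hypothetical class names: position i is the name of class i.
class_names = tf.constant(["cat", "dog", "bird", "fish"])

one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])

# Recover integer indices, then look up the name for each row.
decoded = tf.argmax(one_hot_encoded, axis=1)
decoded_names = tf.gather(class_names, decoded)

print(decoded_names.numpy())  # [b'dog' b'cat' b'fish']
```

String tensors come back from .numpy() as bytes, so call .decode() on each element if you need Python str values.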
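
Answer:

With NumPy and Keras' to_categorical, you can encode the labels and then decode each row with np.argmax: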
import numpy as np
from tensorflow.keras.utils import to_categorical

data = np.array([1, 5, 3, 8])
print(data)


def encode(data):
    print('Shape of data (BEFORE encode): %s' % str(data.shape))
    # to_categorical infers num_classes as max(data) + 1 (9 here).
    encoded = to_categorical(data)
    print('Shape of data (AFTER  encode): %s\n' % str(encoded.shape))
    return encoded


encoded_data = encode(data)
print(encoded_data)


def decode(datum):
    # The position of the single 1 in a one-hot row is the original label.
    return np.argmax(datum)


decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
    datum = encoded_data[i]
    print('index: %d' % i)
    print('encoded datum: %s' % datum)
    decoded_datum = decode(datum)
    print('decoded datum: %s' % decoded_datum)
    # Convert the NumPy integer to a plain Python int for a clean list.
    decoded_Y.append(int(decoded_datum))

print("****************************************")

print(decoded_Y)  # [1, 5, 3, 8]
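
The per-row loop above can also be replaced by a single vectorized call. This NumPy-only sketch builds the one-hot matrix with np.eye instead of to_categorical (an assumption made so it runs without Keras; the 0/1 values are the same) and decodes all rows at once with np.argmax along axis 1.

```python
import numpy as np

data = np.array([1, 5, 3, 8])

# One-hot encode: row i of the identity matrix is the one-hot vector for i.
# num_classes = max(data) + 1 = 9, so encoded has shape (4, 9).
encoded = np.eye(data.max() + 1)[data]

# Decode every row at once by taking argmax along the class axis.
decoded = np.argmax(encoded, axis=1)

print(decoded)  # [1 5 3 8]
```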

Answer:

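In TensorFlow 2.x the function is documented as tf.math.argmax, with tf.argmax kept as an alias for it, so older links to the tf.argmax documentation may return 404. Either name works; tf.math.argmax is the canonical one.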

Usage:

import tensorflow as tf

a = [1, 10, 26.9, 2.8, 166.32, 62.3]
b = tf.math.argmax(input=a)
c = b.numpy()  # TF 2.x eager tensor -> NumPy value
# c = 4
# here a[4] = 166.32, which is the largest element of a along axis 0

Note: you can also do this with NumPy, using np.argmax.
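
For reference, a NumPy version of the same example, plus one caveat worth knowing when decoding one-hot rows:

```python
import numpy as np

a = [1, 10, 26.9, 2.8, 166.32, 62.3]

# For a 1-D input, the default axis=None is equivalent to axis 0.
c = np.argmax(a)
print(c)  # 4

# Ties resolve to the first occurrence, so an all-zero row (no class set)
# silently decodes to 0; check for such rows if some samples may be unlabeled.
empty = np.zeros(4)
print(np.argmax(empty))  # 0
```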