One-hot encode labels of a tf.data.Dataset

I am trying to convert the labels of a tf.data.Dataset to one-hot encoded labels. I am using this dataset. I've added titles (sentiment, text) to the columns; everything else is unchanged.

Here is the code I use to encode the labels (positive, negative, neutral) as one-hot vectors of shape (3,):

def _map_func(text, labels):
   labels_enc = []
   for label in labels:
      if label=='negative':
         label = -1
      elif label=='neutral':
         label = 0
      else: 
         label = 1

      label = tf.one_hot(
         label, 3, name='label', axis=-1)

      labels_enc.append(label)

   return text, labels_enc

raw_train_ds = tf.data.experimental.make_csv_dataset(
   './data/sentiment_data/train.csv', BATCH_SIZE, column_names=['sentiment', 'text'],
   label_name='sentiment', header=True
)

train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)

train_ds = train_ds.map(_map_func)

I am getting the error: ValueError: Value [<tf.Tensor 'while/label:0' shape=(3,) dtype=float32>] is not convertible to a tensor with dtype <dtype: 'float32'> and shape (1, 3).

The second argument of _map_func(text, labels), labels, has shape (64,) and dtype string.

If I understand TensorFlow's tf.data.Dataset.map function correctly, it creates a new dataset by applying the transformation function to each element. But as the error states, the label column can't be converted from a column holding one string per row to a column holding a list of 3 floats per row. Is there any way to force the type of the new column so that it accepts the encoded labels?

Thanks for the help :)

There are 2 best solutions below

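The mapping function is applied per dataset element, so you don't need to build a list and loop over the items yourself. Note that make_csv_dataset yields batches, so each element is a whole batch; the function below expects a single sample, which you get by calling unbatch() on the dataset before map. Try it for one sample only: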

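def _map_func(text, label):
    # Map each class name to an index in [0, 3); tf.one_hot returns an
    # all-zero vector for an out-of-range index such as -1.
    if label == 'negative':
        label = 0
    elif label == 'neutral':
        label = 1
    else:
        label = 2

    label = tf.one_hot(label, 3, name='label', axis=-1)

    return text, label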
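
As a rough illustration of the per-element behaviour described above, here is a minimal sketch that unbatches a dataset, encodes one sample at a time, and batches again. The in-memory data, CLASS_NAMES, and encode_sample are hypothetical stand-ins rather than part of the original question, and the sketch swaps the if/elif chain for a comparison against the class names:

```python
import tensorflow as tf

# Assumed class order: the position of each name is its one-hot index.
CLASS_NAMES = tf.constant(['negative', 'neutral', 'positive'])

def encode_sample(text, label):
    # label is a scalar string here; comparing it against CLASS_NAMES
    # broadcasts to shape (3,), with True only at the matching class.
    return text, tf.cast(tf.equal(label, CLASS_NAMES), tf.float32)

# Hypothetical in-memory stand-in for the batched CSV dataset.
batched_ds = tf.data.Dataset.from_tensor_slices(
    (['great', 'awful', 'fine', 'superb'],
     ['positive', 'negative', 'neutral', 'positive'])
).batch(2)

# Undo the batching, encode one sample at a time, then batch again.
encoded_ds = batched_ds.unbatch().map(encode_sample).batch(2)

first_text, first_labels = next(iter(encoded_ds))
print(first_labels.shape)  # (2, 3)
```

Unlike the else branch above, a label that matches none of the class names becomes an all-zero row here instead of being treated as positive.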

I solved the issue by using a TensorFlow TensorArray like so:

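def _map_func(text, labels):
    i = 0
    labels_enc = tf.TensorArray(tf.float32, size=0, dynamic_size=True,
        clear_after_read=False)
    for label in labels:
        # Indices must lie in [0, 3); tf.one_hot returns all zeros for -1.
        if label == 'negative':
            label = tf.constant(0)
        elif label == 'neutral':
            label = tf.constant(1)
        else:
            label = tf.constant(2)

        label = tf.one_hot(
            label, 3, name='label', axis=-1)

        # write() returns the updated TensorArray, so keep the result.
        labels_enc = labels_enc.write(i, label)
        i = i + 1

    # stack() yields shape (batch, 3); concat() would flatten to (batch * 3,).
    return text, labels_enc.stack()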
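
For comparison, the same batch-wise encoding can probably be done without a Python loop. The following is a sketch under stated assumptions, not the method used above: label_table, encode_batch, and the in-memory demo data are hypothetical names, and the label-to-index mapping (negative=0, neutral=1, positive=2) is an assumed convention.

```python
import tensorflow as tf

# Assumed label-to-index mapping; unknown strings fall back to -1.
label_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(['negative', 'neutral', 'positive']),
        values=tf.constant([0, 1, 2], dtype=tf.int64)),
    default_value=-1)

def encode_batch(text, labels):
    # labels: string tensor of shape (batch,) -> int64 indices, same shape.
    indices = label_table.lookup(labels)
    # tf.one_hot appends an axis of size 3; index -1 yields an all-zero row.
    return text, tf.one_hot(indices, depth=3)

# Hypothetical in-memory stand-in for one batch of the CSV dataset.
demo_ds = tf.data.Dataset.from_tensor_slices(
    (['great', 'awful', 'fine'], ['positive', 'negative', 'neutral'])
).batch(3)

demo_text, demo_labels = next(iter(demo_ds.map(encode_batch)))
```

Because the lookup and tf.one_hot both operate on the whole (batch,) tensor at once, the result already has shape (batch, 3) and no TensorArray is needed.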