Retrieving last value of LSTM sequence in Tensorflow


I have sequences of different lengths that I want to classify using LSTMs in TensorFlow. For the classification I only need the LSTM output at the last valid timestep of each sequence.

import tensorflow as tf  # TensorFlow 1.x API (placeholders, tf.nn.rnn_cell)

max_length = 10
n_dims = 2
layer_units = 5
input = tf.placeholder(tf.float32, [None, max_length, n_dims])
lengths =  tf.placeholder(tf.int32, [None])
cell = tf.nn.rnn_cell.LSTMCell(num_units=layer_units, state_is_tuple=True)

sequence_outputs, last_states = tf.nn.dynamic_rnn(cell, sequence_length=lengths, inputs=input)

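I would like to get the equivalent of this NumPy-style indexing, which picks the last valid timestep of each sequence: output = sequence_outputs[np.arange(len(lengths)), lengths - 1]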
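A small NumPy sketch makes the intended selection concrete. It uses made-up toy values and assumes, as tf.nn.dynamic_rnn does when sequence_length is passed, that outputs past each sequence's length are zero. Two details matter: the last valid step is at index `lengths - 1`, and each batch row has to be paired with its own time index (rather than slicing the batch axis with `:`), otherwise every listed timestep is taken for every sequence.

```python
import numpy as np

# Toy stand-in for sequence_outputs: batch of 2, max_length 4, 3 units.
# Values are made up; timesteps past each length stay zero, mimicking
# dynamic_rnn zeroing outputs beyond sequence_length.
batch_size, max_length, units = 2, 4, 3
lengths = np.array([2, 4])
sequence_outputs = np.zeros((batch_size, max_length, units))
for b in range(batch_size):
    for t in range(lengths[b]):
        sequence_outputs[b, t] = 10 * b + t + 1  # marker value for a valid step

# Naive "last column" slice returns padding for the shorter sequence.
naive_last = sequence_outputs[:, -1]

# Pair each batch row with its own last valid timestep (lengths - 1).
last_output = sequence_outputs[np.arange(batch_size), lengths - 1]

# Slicing the batch axis with ':' takes every listed timestep for every row,
# giving shape (batch_size, batch_size, units) instead of (batch_size, units).
wrong_shape = sequence_outputs[:, lengths - 1].shape
```

Here `last_output` has shape (2, 3): the second step of the first sequence and the fourth step of the second one.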

Is there any way or workaround to get this behaviour in Tensorflow?

---UPDATE---

Following the post "How to select rows from a 3-D Tensor in TensorFlow?", it seems that it is possible to solve the problem efficiently with tf.gather by manipulating the indices. The only requirement is that the batch size must be known in advance. Here is the adaptation of the referenced post to this concrete problem:

max_length = 10
n_dims = 2
layer_units = 5
batch_size = 2
input = tf.placeholder(tf.float32, [batch_size, max_length, n_dims])
lengths =  tf.placeholder(tf.int32, [batch_size])
cell = tf.nn.rnn_cell.LSTMCell(num_units=layer_units, state_is_tuple=True)

sequence_outputs, last_states = tf.nn.dynamic_rnn(cell,
                                                  sequence_length=lengths, inputs=input)

# Code adapted from @mrry's answer on Stack Overflow:
# https://stackoverflow.com/questions/36088277/how-to-select-rows-from-a-3-d-tensor-in-tensorflow
rows_per_batch = tf.shape(input)[1]  # max_length: timesteps per sequence
indices_per_batch = 1                # one timestep (the last) per sequence

# Offset of each sequence's first row once batch and time are flattened
# into a single axis.
offset = tf.range(0, batch_size) * rows_per_batch

# Flat index of each sequence's last valid timestep, plus the outputs
# reshaped to [batch_size * max_length, layer_units] for `tf.gather()`.
flattened_indices = lengths - 1 + offset
flattened_sequence_outputs = tf.reshape(sequence_outputs, tf.concat([[-1],
                             tf.shape(sequence_outputs)[2:]], axis=0))

selected_rows = tf.gather(flattened_sequence_outputs, flattened_indices)
# Result has shape [batch_size, 1, layer_units].
last_output = tf.reshape(selected_rows,
                         tf.concat([tf.stack([batch_size, indices_per_batch]),
                                    tf.shape(sequence_outputs)[2:]], axis=0))
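The index arithmetic in the update can be checked with a NumPy analogue. After reshaping the [batch_size, max_length, units] outputs to [batch_size * max_length, units], sequence `b` starts at flat row `b * max_length`, so its last valid step sits at `b * max_length + lengths[b] - 1`. The sketch below uses `np.take` in place of `tf.gather` and random toy data (both assumptions for illustration), and compares the result with direct fancy indexing.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, max_length, units = 3, 5, 4
sequence_outputs = rng.standard_normal((batch_size, max_length, units))
lengths = np.array([1, 5, 3])

# Offset of each sequence's first timestep in the flattened array.
offset = np.arange(batch_size) * max_length   # [0, 5, 10]
flattened_indices = lengths - 1 + offset      # [0, 9, 12]

# Collapse batch and time into one axis: [batch_size * max_length, units].
flattened_outputs = sequence_outputs.reshape(-1, units)

# np.take along axis 0 plays the role of tf.gather here.
selected_rows = np.take(flattened_outputs, flattened_indices, axis=0)

# Restore a [batch_size, 1, units] shape, as the TensorFlow code does.
last_output = selected_rows.reshape(batch_size, indices_per_batch := 1, units)

# Same rows as indexing the 3-D array directly.
reference = sequence_outputs[np.arange(batch_size), lengths - 1]
```

The extra axis of size 1 comes from `indices_per_batch`; squeezing it gives the [batch_size, units] result.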

@petrux's approach (Get the last output of a dynamic_rnn in TensorFlow) also seems to work, but it builds a list inside a for loop, which may be less efficient. I have not run any benchmark to support this.
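For comparison, the general pattern of building a per-example list and stacking it can be sketched in NumPy. This illustrates the pattern only; it is not @petrux's actual code, and both helper names are hypothetical. The loop runs once per batch element in Python, whereas the vectorized version is a single indexing operation. In a TensorFlow graph, a Python loop of this kind would also need a static batch size and would add one slicing op per example.

```python
import numpy as np

def last_output_loop(sequence_outputs, lengths):
    """Hypothetical helper: pick each sequence's last valid step in a Python loop."""
    rows = []
    for b, length in enumerate(lengths):
        rows.append(sequence_outputs[b, length - 1])  # one slice per example
    return np.stack(rows)                             # [batch_size, units]

def last_output_vectorized(sequence_outputs, lengths):
    """Hypothetical helper: one fancy-indexing call over the whole batch."""
    return sequence_outputs[np.arange(len(lengths)), lengths - 1]

outputs = np.arange(2 * 3 * 2).reshape(2, 3, 2)  # toy data, shape [2, 3, 2]
lengths = np.array([3, 1])
loop_result = last_output_loop(outputs, lengths)
vec_result = last_output_vectorized(outputs, lengths)
```

Both functions return the same [batch_size, units] array; the difference is only in how many operations are issued.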

Accepted answer:

This could be an answer. I don't think there is anything similar to the NumPy notation you pointed out, but the effect is the same.

Another answer:

Here's a solution using tf.gather_nd, where the batch size does not need to be known ahead of time.

import tensorflow as tf

def extract_axis_1(data, ind):
    """
    Get one element along axis 1 of a tensor for each position along axis 0.
    :param data: TensorFlow tensor of shape [batch, time, ...] to subset.
    :param ind: int32 tensor of indices along axis 1 (one per element along axis 0 of data).
    :return: Subsetted tensor of shape [batch, ...].
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)

    return res

output = extract_axis_1(sequence_outputs, lengths - 1)

Now output is a tensor of shape [batch_size, layer_units].
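To see what `tf.stack([batch_range, ind], axis=1)` builds inside extract_axis_1, here is a NumPy analogue with toy data (an illustration, not TensorFlow code). Each row of `indices` is a (batch, time) pair, and `tf.gather_nd` returns `data[batch, time]` for every pair, which NumPy expresses as indexing with the two columns.

```python
import numpy as np

# Toy outputs: [batch_size=3, max_length=4, units=2]
data = np.arange(3 * 4 * 2).reshape(3, 4, 2)
lengths = np.array([4, 2, 1])
ind = lengths - 1                                # last valid timestep per row

batch_range = np.arange(data.shape[0])           # like tf.range(tf.shape(data)[0])
indices = np.stack([batch_range, ind], axis=1)   # like tf.stack(..., axis=1)

# gather_nd over (batch, time) pairs == NumPy indexing by the two columns
output = data[indices[:, 0], indices[:, 1]]      # shape [batch_size, units]
```

Because the batch range is computed from the runtime shape rather than a Python constant, the same construction works in TensorFlow when the batch dimension is None.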