I'm currently working on a project that predicts the future trajectories of objects from their past trajectories using LSTMs in TensorFlow. (Here, a trajectory means a sequence of 2D positions.)
The input to the LSTM is, of course, the past trajectories, and the output is the future trajectories.
The mini-batch size is fixed during training. However, the number of past trajectories in a mini-batch can vary. For example, let the mini-batch size be 10. If I have only 4 past trajectories for the current training iteration, the remaining 6 slots in the mini-batch are padded with zeros.
When calculating the loss for back-propagation, I set the loss from those 6 padded slots to zero so that only the 4 real trajectories contribute to back-propagation.
My concern is that TensorFlow still seems to compute gradients for the 6 padded slots even though their loss is zero. As a result, training becomes slower as I increase the mini-batch size, even when I use the same training data.
I also tried using tf.where when calculating the loss, but the training time did not decrease.
How can I reduce the training time?
Here is the pseudocode for my training loop:
    # For each frame in a sequence
    for f in range(pred_length):
        # For each element in a batch
        for b in range(batch_size):
            with tf.variable_scope("rnnlm") as scope:
                if (f > 0 or b > 0):
                    scope.reuse_variables()
                # for each pedestrian in an element
                for p in range(MNP):
                    # ground-truth position
                    cur_gt_pose_dec = ...
                    # loss mask
                    loss_mask_ped = ...  # '1' or '0'
                    # go through RNN decoder
                    output_states_dec_list[b][p], zero_states_dec_list[b][p] = cell_dec(
                        cur_embed_frm_dec, zero_states_dec_list[b][p])
                    # fully connected layer for output
                    cur_pred_pose_dec = tf.nn.xw_plus_b(output_states_dec_list[b][p], output_wd, output_bd)
                    # go through embedding function for the next input
                    prev_embed_frms_dec_list[b][p] = tf.reshape(
                        tf.nn.relu(tf.nn.xw_plus_b(cur_pred_pose_dec, embedding_wd, embedding_bd)),
                        shape=(1, rnn_size))
                    # calculate MSE loss
                    mse_loss = tf.reduce_sum(tf.pow(tf.subtract(cur_pred_pose_dec, cur_gt_pose_dec), 2.0))
                    # only valid ped's traj contributes to the loss
                    self.loss += tf.multiply(mse_loss, loss_mask_ped)
I think you're looking for the function tf.stop_gradient. Using this, you could do something like
    tf.where(loss_mask, tensor, tf.stop_gradient(tensor))
to achieve the desired result, assuming that the dimensions are correct.
However, this is probably not your real issue. It looks as though you are defining new graph nodes for each item in your dataset. That is not how TensorFlow is meant to be used: you should build a single graph beforehand that performs a fixed computation regardless of the batch size. You should definitely not define new nodes for every element in the batch, because that cannot take advantage of parallelism efficiently.
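To illustrate the difference, here is a minimal sketch in plain NumPy rather than TensorFlow, so it is not your model. The function names, the array shapes (batch, pedestrians, 2), and the random data are all assumptions made for the example. It computes the same masked squared-error loss once with a per-element loop, as in your pseudocode, and once as a single expression over the whole batch, and it also shows that the gradient for the padded entries is exactly zero even though those entries are still computed.

```python
import numpy as np

def masked_mse_loop(pred, gt, mask):
    """Per-element loop, mirroring the structure of the pseudocode (hypothetical helper)."""
    total = 0.0
    batch_size, max_peds, _ = pred.shape
    for b in range(batch_size):
        for p in range(max_peds):
            sq_err = np.sum((pred[b, p] - gt[b, p]) ** 2)
            total += sq_err * mask[b, p]
    return total

def masked_mse_batched(pred, gt, mask):
    """The same loss written as one expression over the whole batch (hypothetical helper)."""
    sq_err = np.sum((pred - gt) ** 2, axis=-1)  # shape: (batch, max_peds)
    return np.sum(sq_err * mask)

def masked_mse_grad(pred, gt, mask):
    """Analytic gradient of the masked loss with respect to pred."""
    return 2.0 * (pred - gt) * mask[..., None]

# Assumed sizes: batch of 10, up to 3 pedestrians each, 2D positions.
rng = np.random.default_rng(0)
batch_size, max_peds = 10, 3
pred = rng.normal(size=(batch_size, max_peds, 2))
gt = rng.normal(size=(batch_size, max_peds, 2))

# Only the first 4 batch elements hold real trajectories; the rest are padding.
mask = np.zeros((batch_size, max_peds))
mask[:4] = 1.0

loss_loop = masked_mse_loop(pred, gt, mask)
loss_batched = masked_mse_batched(pred, gt, mask)
grad = masked_mse_grad(pred, gt, mask)
```

In TensorFlow, the batched form would correspond to building the loss once from tensors that carry the batch and pedestrian dimensions, instead of adding separate nodes for every (b, p) pair. The masking itself stays the same: multiply the per-pedestrian error by the mask before summing.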