How to do selective back-propagation within a mini-batch in TensorFlow?


I'm currently working on a project that predicts the future trajectories of objects from their past trajectories using LSTMs in TensorFlow. (Here, a trajectory is a sequence of 2D positions.)

The input to the LSTM is the past trajectories, and the output is the future trajectories.

The mini-batch size is fixed during training. However, the number of past trajectories in a mini-batch can vary. For example, let the mini-batch size be 10. If I have only 4 past trajectories for the current training iteration, the remaining 6 of the 10 slots in the mini-batch are padded with zeros.

When calculating the loss for back-propagation, I set the loss from those 6 padded entries to zero so that only the 4 real trajectories contribute.

My concern is that TensorFlow still seems to calculate gradients for the 6 padded entries even though their loss is zero. As a result, training becomes slower as I increase the mini-batch size, even when I use the same training data.

I also tried using tf.where when calculating the loss, but the training time did not decrease.

How can I reduce the training time?

My training pseudocode is attached below.

# For each frame in a sequence
for f in range(pred_length):

    # For each element in a batch
    for b in range(batch_size):


        with tf.variable_scope("rnnlm") as scope:
            if (f > 0 or b > 0):
                scope.reuse_variables()

            # for each pedestrian in an element
            for p in range(MNP):

                # ground-truth position
                cur_gt_pose_dec = ...

                # loss mask
                loss_mask_ped = ... # '1' or '0'

                # go through RNN decoder
                output_states_dec_list[b][p], zero_states_dec_list[b][p] = cell_dec(cur_embed_frm_dec,
                                                                                    zero_states_dec_list[b][p])

                # fully connected layer for output
                cur_pred_pose_dec = tf.nn.xw_plus_b(output_states_dec_list[b][p], output_wd, output_bd)

                # go through embedding function for the next input
                prev_embed_frms_dec_list[b][p] = tf.reshape(tf.nn.relu(tf.nn.xw_plus_b(cur_pred_pose_dec, embedding_wd, embedding_bd)), shape=(1, rnn_size))

                # calculate MSE loss
                mse_loss = tf.reduce_sum(tf.pow(tf.subtract(cur_pred_pose_dec, cur_gt_pose_dec), 2.0))

                # only valid ped's traj contributes to the loss
                self.loss += tf.multiply(mse_loss, loss_mask_ped)

There is 1 best solution below


I think you're looking for tf.stop_gradient. With it, you could write something like tf.where(loss_mask, tensor, tf.stop_gradient(tensor)) to get the desired result, assuming the dimensions are compatible.
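
To make the tf.where / tf.stop_gradient idea concrete, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution (the question's code is TF 1.x style), and the toy tensors pred, gt and valid are made up for illustration only.

```python
import tensorflow as tf

# Hypothetical toy data: 4 slots of 2D positions; the last two slots are padding.
pred = tf.Variable([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [0.0, 0.0]])
gt = tf.constant([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [5.0, 5.0]])
valid = tf.constant([True, True, False, False])

with tf.GradientTape() as tape:
    # Valid rows keep their gradient path; padded rows go through stop_gradient.
    gated = tf.where(valid[:, tf.newaxis], pred, tf.stop_gradient(pred))
    loss = tf.reduce_sum(tf.square(gated - gt))

# Gradient is 2 * (pred - gt) on valid rows and exactly zero on padded rows.
grad = tape.gradient(loss, pred)
print(grad.numpy())
```

Note that this only blocks the gradient. The forward computation for the padded rows still runs, so on its own it is unlikely to reduce training time much.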

However, this is probably not your real issue. It looks as though you are defining new graph nodes for each item in your dataset. This is not how TensorFlow is meant to be used: you should build a single graph beforehand that performs a fixed function regardless of the batch size. You should definitely not define new nodes for every element in the batch, since that cannot take efficient advantage of parallelism.
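
As a rough illustration of what "one graph for the whole batch" could look like for the loss, the sketch below computes the masked squared error with a few batched ops instead of one set of nodes per element and pedestrian. It assumes TensorFlow 2.x. The function name masked_mse and the tensor layout ([batch_size, max_peds, 2] for positions, [batch_size, max_peds] for the mask) are assumptions for this example, not taken from the question's code.

```python
import tensorflow as tf

def masked_mse(pred, gt, mask):
    """Sum of squared errors over valid (unpadded) entries only.

    pred, gt: float tensors of shape [batch_size, max_peds, 2] (assumed layout)
    mask:     float tensor of shape [batch_size, max_peds], 1.0 = valid, 0.0 = padded
    """
    # Squared error per pedestrian, summed over the x/y coordinates.
    sq_err = tf.reduce_sum(tf.square(pred - gt), axis=-1)
    # Zero out padded entries, then reduce over the whole batch in one op.
    return tf.reduce_sum(sq_err * mask)

# Hypothetical batch: 2 elements, up to 3 pedestrians each, 3 valid entries in total.
pred = tf.ones([2, 3, 2])
gt = tf.zeros([2, 3, 2])
mask = tf.constant([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

loss = masked_mse(pred, gt, mask)
print(float(loss))
```

The same pattern applies to the decoder itself: feeding a [batch_size * max_peds, ...] tensor through the RNN cell once per frame keeps the graph size independent of the batch size.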