How to print the gradients during training in Tensorflow?

In order to debug a TensorFlow model I need to see whether the gradients change during training and whether any of them contain NaNs. Simply printing a variable in TensorFlow does not help, because all you see is its static description:

 <tf.Variable 'Model/embedding:0' shape=(8182, 100) dtype=float32_ref>
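(Evaluating the variable in a session does return the actual values; a minimal sketch, assuming the variable object is reachable as model.embedding, a name made up here for illustration:)

# Hypothetical sketch: session.run() returns the variable's current value as a
# NumPy array, whereas print() only shows the static description above.
emb = session.run(model.embedding)  # "model.embedding" is an assumed attribute
print(emb.shape, emb[0][:5])        # e.g. (8182, 100) and the first few values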

I tried to use the tf.Print op, but couldn't make it work, and I wonder whether it can actually be used this way. In my model I have a training loop that prints the loss of each epoch:

def run_epoch(session, model, eval_op=None, verbose=False):
    """Runs one epoch over the data and returns the accumulated cost."""
    costs = 0.0
    iters = 0
    state = session.run(model.initial_state)

    # Tensors to evaluate at every step.
    fetches = {
        "cost": model.cost,
        "final_state": model.final_state,
    }
    if eval_op is not None:
        fetches["eval_op"] = eval_op

    for step in range(model.input.epoch_size):
        # Feed the final RNN state of the previous step back in as the
        # initial state of the current one.
        feed_dict = {}
        for i, (c, h) in enumerate(model.initial_state):
            feed_dict[c] = state[i].c
            feed_dict[h] = state[i].h

        vals = session.run(fetches, feed_dict)
        cost = vals["cost"]
        state = vals["final_state"]

        costs += cost
        iters += model.input.num_steps

    print("Loss:", costs)

    return costs

Inserting print(model.gradients[0][1]) into this function won't work (it only prints the tensor's description, not its values), so I tried the following right after the loss print:

grads = model.gradients[0][1]
x = tf.Print(grads, [grads])
session.run(x)

But I got the following error message:

ValueError: Fetch argument <tf.Tensor 'mul:0' shape=(8182, 100) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("mul:0", shape=(8182, 100), dtype=float32) is not an element of this graph.)

Which makes sense, because the tensor returned by tf.Print was created after the fact and is not an element of the model's graph. So I tried calling tf.Print right after the loss computation in the actual graph, but that didn't work either: printing the returned tensor still only shows Tensor("Train/Model/mul:0", shape=(8182, 100), dtype=float32) instead of the values.
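As far as I understand, tf.Print is an identity op with a printing side effect: it has to be wired in at graph-construction time, and the tensor it returns must be used downstream, so that evaluating the graph triggers the print. A minimal sketch of what that would look like, assuming it runs while the graph is being built:

# Hypothetical sketch: wire tf.Print into the graph at construction time.
# tf.Print returns a tensor identical to its first argument; the tensors in
# the list are printed each time the returned tensor is evaluated, so the
# returned tensor must replace the original one downstream.
grads = model.gradients[0][1]
model.cost = tf.Print(model.cost, [grads], message="grads: ", summarize=10)
# Every session.run() that fetches model.cost now also prints the gradients.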

How can I print the gradients inside the training loop in TensorFlow?

BEST ANSWER

In my experience, the best way to watch the gradient flow in TensorFlow is not tf.Print but TensorBoard. Here's sample code I used in another problem where gradients were the key issue in learning:

# grads_and_vars as returned by optimizer.compute_gradients(loss)
for g, v in grads_and_vars:
    # v.name contains ':' (e.g. 'w:0'), which is illegal in a summary tag;
    # TensorFlow substitutes it and logs a warning, so this still works.
    tf.summary.histogram(v.name, v)
    tf.summary.histogram(v.name + '_grad', g)

merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('train_log_layer', tf.get_default_graph())

...

# Inside the training loop. `I` is the input placeholder of the model this
# snippet came from, and `i` is the step counter.
_, summary = sess.run([train_op, merged], feed_dict={I: 2*np.random.rand(1, 1)-1})
if i % 10 == 0:
    writer.add_summary(summary, global_step=i)

This will show you the distribution of gradients over time. By the way, to check for NaN there is a dedicated function in TensorFlow: tf.is_nan. Usually you don't need to check whether the gradient is NaN: when that happens, the variable explodes as well, which is clearly visible in TensorBoard.
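If you still want to print raw gradient values inside the training loop, as the question asks, you can fetch the gradient tensors directly with session.run. A minimal sketch, assuming model.gradients is a list of (gradient, variable) pairs as produced by optimizer.compute_gradients, and feed_dict is built as in run_epoch above:

# Build these once, at graph-construction time.
grads = model.gradients[0][0]              # with compute_gradients, the
                                           # gradient is the first element
has_nan = tf.reduce_any(tf.is_nan(grads))  # scalar bool: any NaN in the grads?

# Inside the training loop: fetch and print the actual values.
grad_val, nan_val = session.run([grads, has_nan], feed_dict)
print("gradient values:", grad_val)
print("contains NaN:", nan_val)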