How to find the analytical gradient using TensorFlow GradientTape


Suppose we have the function y = x^2.

We can then use GradientTape to compute the gradient automatically when we give TensorFlow a value of x:

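import tensorflow as tf

x = tf.Variable(3.0)

# Record the forward computation on the tape
with tf.GradientTape() as tape:
  y = x**2

# Ask for the gradient after leaving the tape context
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 6.0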

Is there any way to find out what TensorFlow did with my input? In this case it is easy to work out that dy/dx = 2x, so does TensorFlow simply multiply my input x = 3 by 2 and return 6 (3*2)?

I have a very complicated function that I don't know how to differentiate, so I would like to use GradientTape to gain insight into how TensorFlow works out the derivative from my input value of x.
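To make the underlying idea concrete, here is a toy, pure-Python sketch of reverse-mode automatic differentiation, the general technique GradientTape uses. It is an illustration under simplifying assumptions (scalars only, a hypothetical Var class with just +, *, ** and sin), not TensorFlow's implementation: each operation records its inputs together with its local derivative evaluated at the current numbers, and backward multiplies those local derivatives along the chain rule.

```python
import math


class Var:
    """Toy scalar that remembers which ops produced it (illustration only)."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each entry is (parent Var, local derivative d(self)/d(parent)),
        # already evaluated at the current numeric values.
        self.parents = parents
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __rmul__ = __mul__

    def __pow__(self, n):
        # n is assumed to be a plain number; power rule n * x**(n - 1)
        return Var(self.value ** n, ((self, n * self.value ** (n - 1)),))


def sin(v):
    """Hypothetical sin op: local derivative is cos(x)."""
    return Var(math.sin(v.value), ((v, math.cos(v.value)),))


def backward(output):
    """Reverse-mode pass: push gradients from the output back to the inputs."""
    order, seen = [], set()

    def visit(v):
        # Topological order so each node is finished before its parents
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)

    visit(output)
    output.grad = 1.0  # seed: d(output)/d(output) = 1
    for v in reversed(order):
        for parent, local in v.parents:
            parent.grad += v.grad * local


x = Var(3.0)
y = x ** 2
backward(y)
print(x.grad)  # 6.0
```

No symbolic expression such as 2x is ever built here: the power rule is evaluated numerically at x = 3, giving 2 * 3 = 6, and a more complicated function is just a longer chain of such local derivatives.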




One option is to print the operations in the traced function's tf.Graph and to visualize that graph in TensorBoard:

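import tensorflow as tf

x = tf.Variable(3.0)

@tf.function
def f(x):
  with tf.GradientTape() as tape:
    y = x**2
  return tape.gradient(y, x)

# Print every tensor produced by the ops in the traced graph:
# the forward pass plus the ops GradientTape added to compute the gradient.
print(*[tensor for op in f.get_concrete_function(x).graph.get_operations() for tensor in op.values()], sep="\n")

# Record the graph of the next call to f and write it to logdir for TensorBoard.
# Only the graph is needed here, so profiling is left off.
logdir = 'logs/func/'
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=False)
z = f(x)
with writer.as_default():
  tf.summary.trace_export(
      name="func_trace",
      step=0)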
Output of the print statement:

Tensor("x:0", shape=(), dtype=resource)
Tensor("ReadVariableOp:0", shape=(), dtype=float32)
Tensor("pow/y:0", shape=(), dtype=float32)
Tensor("pow:0", shape=(), dtype=float32)
Tensor("ones:0", shape=(), dtype=float32)
Tensor("gradient_tape/pow/mul:0", shape=(), dtype=float32)
Tensor("gradient_tape/pow/sub/y:0", shape=(), dtype=float32)
Tensor("gradient_tape/pow/sub:0", shape=(), dtype=float32)
Tensor("gradient_tape/pow/Pow:0", shape=(), dtype=float32)
Tensor("gradient_tape/pow/mul_1:0", shape=(), dtype=float32)
Tensor("Identity:0", shape=(), dtype=float32)
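Reading the op names in this printout, the gradient part of the graph appears to apply the general power rule dz/dx = upstream * y * x^(y - 1) for z = x^y, rather than a pre-simplified 2x. The mapping below (gradient_tape/pow/mul multiplies the seed ones by the exponent pow/y, gradient_tape/pow/sub computes y - 1, gradient_tape/pow/Pow raises x to that power, gradient_tape/pow/mul_1 multiplies the two) is an interpretation of the names, not something the printout states directly. A plain-Python sketch of that arithmetic, where pow_grad is a hypothetical helper:

```python
def pow_grad(x, y, upstream=1.0):
    """Replays the gradient ops for z = x ** y w.r.t. x (interpreted from op names)."""
    mul = upstream * y   # gradient_tape/pow/mul:   ones * pow/y
    sub = y - 1.0        # gradient_tape/pow/sub:   pow/y - 1
    pw = x ** sub        # gradient_tape/pow/Pow:   x ** (y - 1)
    return mul * pw      # gradient_tape/pow/mul_1: final dz/dx


print(pow_grad(3.0, 2.0))  # 6.0
```

So for x = 3 TensorFlow does end up with 2 * 3 = 6, but it gets there by evaluating the generic rule numerically at the input value.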

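Then open TensorBoard. In a Jupyter or Colab notebook, use the magic commands below (from a terminal, run tensorboard --logdir logs/func instead):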

%load_ext tensorboard
%tensorboard --logdir logs/func
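If TensorBoard is not available, a text-only view of the same graph can show which tensors feed each op. A minimal sketch, assuming TensorFlow 2.x; describe_graph is a hypothetical helper built only on the get_operations(), op.name, op.type and op.inputs attributes of a traced graph:

```python
import tensorflow as tf

x = tf.Variable(3.0)


@tf.function
def f(x):
    with tf.GradientTape() as tape:
        y = x**2
    return tape.gradient(y, x)


def describe_graph(concrete_fn):
    """Return (op name, op type, input tensor names) for every op in the graph."""
    rows = []
    for op in concrete_fn.graph.get_operations():
        rows.append((op.name, op.type, [t.name for t in op.inputs]))
    return rows


rows = describe_graph(f.get_concrete_function(x))
for name, op_type, inputs in rows:
    # e.g. an op's name, its kernel type (Pow, Mul, Sub, ...) and what feeds it
    print(f"{name:28s} {op_type:14s} <- {inputs}")
```

Following the inputs of the ops under gradient_tape/ from the bottom up traces the chain-rule computation step by step, which is the same information the TensorBoard graph view displays visually.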
