How can I calculate gradients of each element in a tensor using GradientTape?


I would like to calculate the gradient of each element in a tensor with respect to a list of watched tensors.

When I use GradientTape's gradient() on y directly, the resulting dy_dx has the dimension of my x. For example:

x = [ tf.constant(3.0), tf.constant(4.0), tf.constant(5.0) ]
with tf.GradientTape(persistent=True) as g:
    g.watch(x)
    y_as_list = [ x[0]*x[1]*x[2], x[0]*x[1]*x[2]*x[0] ]
    y_as_tensor = tf.stack(y_as_list, axis=0)

print("---------------------------")
print("x:", x)
print("y:", y_as_tensor)
print("y:", y_as_list)

dy_dx_from_tensor = g.gradient(y_as_tensor, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
dy_dx_from_list = g.gradient(y_as_list, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)

print("---------------------------")
print("dy_dx_from_tensor:", dy_dx_from_tensor)
print("dy_dx_from_list:", dy_dx_from_list)

results in:

---------------------------
x: [<tf.Tensor: shape=(), dtype=float32, numpy=3.0>, <tf.Tensor: shape=(), dtype=float32, numpy=4.0>, <tf.Tensor: shape=(), dtype=float32, numpy=5.0>]
y: tf.Tensor([ 60. 180.], shape=(2,), dtype=float32)
y: [<tf.Tensor: shape=(), dtype=float32, numpy=60.0>, <tf.Tensor: shape=(), dtype=float32, numpy=180.0>]
---------------------------
dy_dx_from_tensor: [<tf.Tensor: shape=(), dtype=float32, numpy=140.0>, <tf.Tensor: shape=(), dtype=float32, numpy=60.0>, <tf.Tensor: shape=(), dtype=float32, numpy=48.0>]
dy_dx_from_list: [<tf.Tensor: shape=(), dtype=float32, numpy=140.0>, <tf.Tensor: shape=(), dtype=float32, numpy=60.0>, <tf.Tensor: shape=(), dtype=float32, numpy=48.0>]

Note that the results of both the tensor and the list versions have the same dimensions as the watched x.
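The values 140, 60 and 48 are consistent with gradient() summing the per-element gradients when the target is not a scalar. The following is a minimal sketch in plain Python (no TensorFlow) that checks this arithmetic with central differences; the helper names y and numerical_jacobian are made up for this illustration and are not part of any library.

```python
def y(x0, x1, x2):
    """The two outputs from the question, as plain Python floats."""
    return [x0 * x1 * x2, x0 * x1 * x2 * x0]


def numerical_jacobian(f, point, h=1e-4):
    """Central-difference estimate of d f_i / d x_j, one row per output."""
    n_out = len(f(*point))
    rows = [[0.0] * len(point) for _ in range(n_out)]
    for j in range(len(point)):
        hi = list(point)
        lo = list(point)
        hi[j] += h
        lo[j] -= h
        f_hi, f_lo = f(*hi), f(*lo)
        for i in range(n_out):
            rows[i][j] = (f_hi[i] - f_lo[i]) / (2 * h)
    return rows


jac = numerical_jacobian(y, [3.0, 4.0, 5.0])
# Summing over the outputs gives one value per input, like gradient() on the stacked y.
column_sums = [sum(row[j] for row in jac) for j in range(3)]

print([[round(v, 3) for v in row] for row in jac])
print([round(v, 3) for v in column_sums])
```

The rows of jac are the per-element gradients (approximately 20, 15, 12 and 120, 45, 36), and the column sums are approximately 140, 60 and 48, which matches dy_dx_from_tensor above.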

When I call the tape's gradient method for each element, I get what I want for the list, but for the tensor all gradients are zero:

dy_dx_from_tensor_elements = [ g.gradient(y_i, x, unconnected_gradients=tf.UnconnectedGradients.ZERO) for y_i in y_as_tensor ]
dy_dx_from_list_elements = [ g.gradient(y_i, x, unconnected_gradients=tf.UnconnectedGradients.ZERO) for y_i in y_as_list ]

print("---------------------------")
print("dy_dx_from_tensor_elements:", dy_dx_from_tensor_elements)
print("dy_dx_from_list_elements:", dy_dx_from_list_elements)

yields:

dy_dx_from_tensor_elements: [[<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>], [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>]]
dy_dx_from_list_elements: [[<tf.Tensor: shape=(), dtype=float32, numpy=20.0>, <tf.Tensor: shape=(), dtype=float32, numpy=15.0>, <tf.Tensor: shape=(), dtype=float32, numpy=12.0>], [<tf.Tensor: shape=(), dtype=float32, numpy=120.0>, <tf.Tensor: shape=(), dtype=float32, numpy=45.0>, <tf.Tensor: shape=(), dtype=float32, numpy=36.0>]]

The dy_dx_from_list_elements values are what I am looking for but I would really like to be able to get them from the tensor because my real world model outputs the y values as a tensor.

Any suggestions on how I could generate gradients for every element in a tensor would be much appreciated!

Best answer

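I think the problem comes from iterating over the tensor outside the tape. Iterating over a tensor probably runs tf.unstack or a similar slicing operation internally, and every TensorFlow operation has to run inside the scope of the gradient tape to be recorded. Gradients are only calculated for a tensor with respect to another tensor that was involved in its recorded calculation, so the slices created outside the tape look unconnected to x and, with unconnected_gradients set to ZERO, come back as zeros. A couple of examples that do the splitting inside the tape: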

import tensorflow as tf

x = [ tf.constant(3.0), tf.constant(4.0), tf.constant(5.0) ]
with tf.GradientTape(persistent=True) as g:
    g.watch(x)
    y_as_list = [ x[0]*x[1]*x[2], x[0]*x[1]*x[2]*x[0] ]
    y_as_tensor = tf.stack(y_as_list, axis=0)
    t = tf.unstack(y_as_tensor)


dy_dx_from_tensor_elements = [g.gradient(y_i, x, unconnected_gradients=tf.UnconnectedGradients.ZERO) for y_i in t]
dy_dx_from_list_elements = [g.gradient(y_i, x, unconnected_gradients=tf.UnconnectedGradients.ZERO) for y_i in y_as_list]

print("---------------------------")
print("dy_dx_from_tensor_elements:", dy_dx_from_tensor_elements)
print("dy_dx_from_list_elements:", dy_dx_from_list_elements)

which prints:

---------------------------
dy_dx_from_tensor_elements: [[<tf.Tensor: shape=(), dtype=float32, numpy=20.0>, <tf.Tensor: shape=(), dtype=float32, numpy=15.0>, <tf.Tensor: shape=(), dtype=float32, numpy=12.0>], [<tf.Tensor: shape=(), dtype=float32, numpy=120.0>, <tf.Tensor: shape=(), dtype=float32, numpy=45.0>, <tf.Tensor: shape=(), dtype=float32, numpy=36.0>]]
dy_dx_from_list_elements: [[<tf.Tensor: shape=(), dtype=float32, numpy=20.0>, <tf.Tensor: shape=(), dtype=float32, numpy=15.0>, <tf.Tensor: shape=(), dtype=float32, numpy=12.0>], [<tf.Tensor: shape=(), dtype=float32, numpy=120.0>, <tf.Tensor: shape=(), dtype=float32, numpy=45.0>, <tf.Tensor: shape=(), dtype=float32, numpy=36.0>]]

The same applies when you use, for example, tf.split:

import tensorflow as tf

x = [ tf.constant(3.0), tf.constant(4.0), tf.constant(5.0) ]
with tf.GradientTape(persistent=True) as g:
    g.watch(x)
    y_as_list = [ x[0]*x[1]*x[2], x[0]*x[1]*x[2]*x[0] ]
    y_as_tensor = tf.stack(y_as_list, axis=0)
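    t = tf.split(y_as_tensor, 2)  # two tensors of shape (1,), recorded by the tape

# Each piece is connected to x, so the per-element gradients are non-zero.
dy_dx_from_tensor_elements = [g.gradient(y_i, x, unconnected_gradients=tf.UnconnectedGradients.ZERO) for y_i in t]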

According to the docs:

The tape can't record the gradient path if the calculation exits TensorFlow.

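Note that tf.stack itself is differentiable: in the first example the gradients flow back through tf.stack and tf.unstack to x. Only operations executed outside the tape, such as the Python iteration over y_as_tensor, are lost.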
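If the goal is the gradient of every element of a tensor output, two other options may be worth trying. This is a sketch rather than a definitive recipe: it assumes eager execution in TensorFlow 2.x, and the names y_elements, per_element and jac_matrix are chosen only for this illustration. The first option indexes the tensor inside the tape; the second uses the tape's jacobian method, which returns all per-element gradients in one call.

```python
import tensorflow as tf

x = [tf.constant(3.0), tf.constant(4.0), tf.constant(5.0)]
with tf.GradientTape(persistent=True) as g:
    g.watch(x)
    y_as_tensor = tf.stack([x[0] * x[1] * x[2], x[0] * x[1] * x[2] * x[0]], axis=0)
    # Option 1: slice inside the tape so that every slice is recorded.
    y_elements = [y_as_tensor[i] for i in range(y_as_tensor.shape[0])]

# Option 1: one gradient call per recorded slice (needs persistent=True).
per_element = [
    g.gradient(y_i, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
    for y_i in y_elements
]

# Option 2: the Jacobian of the whole tensor with respect to each watched tensor.
# The result is a list with one tensor of shape (2,) per entry of x.
jac = g.jacobian(y_as_tensor, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)

# Stack to shape (2, 3): row i holds the gradients of y[i] with respect to x.
jac_matrix = tf.stack(jac, axis=1)

print("per_element:", per_element)
print("jac_matrix:", jac_matrix)
```

For a model whose output is already a tensor, the jacobian call avoids the Python loop entirely, while the slicing variant keeps the list-of-lists layout used in the question.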