Why is the gradient of `where(x > 1, log(x), 0)` nan?


Why is the gradient of `tf.where(x > 1, tf.math.log(x), 0)` nan when `x` is 0.0, but not when it is -1 or 1?

Minimal example:

```python
import tensorflow as tf

x = tf.constant([-1, 0, 1], tf.float32)

with tf.GradientTape() as g:
  g.watch(x)  # x is a constant, so the tape must be told to track it
  y = tf.where(x > 1, tf.math.log(x), 0)

print(y)

dy_dx = g.gradient(y, x)
print(dy_dx)
```

Output:

```text
tf.Tensor([0. 0. 0.], shape=(3,), dtype=float32)
tf.Tensor([-0. nan  0.], shape=(3,), dtype=float32)
```
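A likely explanation is that `tf.where` does not stop the gradient from reaching the unselected branch. It multiplies that gradient by zero instead. On the backward pass, the log branch receives an upstream gradient of 0 wherever `x > 1` is False. The chain rule then multiplies that 0 by the derivative of log, which is 1/x:

- For x = -1, the result is 0 * -1 = -0.
- For x = 1, the result is 0 * 1 = 0.
- For x = 0, the result is 0 * inf = nan.

The forward value log(-1) = nan does not matter here, because the gradient only uses 1/x.

The NumPy sketch below reproduces that arithmetic by hand. It assumes TensorFlow's registered gradients behave as just described: `where` routes the upstream gradient to the selected branch and zeros elsewhere, and log's gradient is upstream * 1/x. `where_log_grad` and `safe_where_log_grad` are hypothetical helper names, not TensorFlow APIs.

The second function models the commonly used "double where" workaround. An inner where replaces the inputs of the discarded branch before they reach log. In TensorFlow this would be something like `tf.where(x > 1, tf.math.log(tf.where(x > 1, x, 1.0)), 0.0)`.

```python
import numpy as np


def where_log_grad(x, cond):
    """Hand-written backward pass of y = where(cond, log(x), 0) with upstream grad 1.

    Assumption: mirrors TensorFlow's gradients for where and log.
    where() sends the upstream gradient (1) to the log branch where cond is
    True and 0 where it is False; log's gradient then multiplies by 1/x for
    every element, selected or not.
    """
    upstream_to_log = np.where(cond, 1.0, 0.0).astype(np.float32)
    with np.errstate(divide="ignore", invalid="ignore"):
        local = np.reciprocal(x)        # 1/x, which is inf at x == 0
        return upstream_to_log * local  # 0 * inf -> nan


def safe_where_log_grad(x, cond):
    """Same backward pass for the "double where" rewrite
    where(cond, log(where(cond, x, 1)), 0)."""
    # Replace inputs on the discarded branch with 1.0 so 1/x stays finite there.
    safe_x = np.where(cond, x, np.float32(1.0))
    upstream_to_log = np.where(cond, 1.0, 0.0).astype(np.float32)
    grad_safe_x = upstream_to_log * np.reciprocal(safe_x)
    # The inner where passes the gradient back to x only where cond is True.
    return np.where(cond, grad_safe_x, np.float32(0.0))


x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
cond = x > 1
print(where_log_grad(x, cond))       # naive version: nan at x == 0
print(safe_where_log_grad(x, cond))  # double-where version: all zeros
```

The naive version yields -0, nan, 0, which matches the TensorFlow output in the question. The double-where version yields zeros everywhere. Where the condition holds it still gives the ordinary 1/x (for example, 0.5 at x = 2).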
