TensorFlow loses track of variables/gradients after multiplication with a constant tensor

I have a TensorFlow model with a custom layer. I create my tf.Variables in the build() method by calling self.add_weight(), as recommended. I then multiply these weights by a constant tensor before they are used in call() (think of it as a change of basis). After that, TensorFlow appears to lose track of my variables, although they still show up in the layer's trainable variables. Here is a minimal example that reproduces the problem:

import tensorflow as tf


class ToyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # Fixed (non-trainable) basis vector.
        self.basis_vector = tf.constant([1., 0., 1.])

    def build(self, input_shape):
        self.variable = self.add_weight(shape=(1,))
        # The product is computed once here and stored as a plain tensor.
        self.effective_weight = self.variable * self.basis_vector

    def call(self, inputs, **kwargs):
        return tf.tensordot(inputs, self.effective_weight, axes=1)


layer = ToyLayer()
x = tf.random.normal((3,))
with tf.GradientTape() as tape:
    y = layer(x)
print(layer.trainable_weights)
print(tape.gradient(y, layer.trainable_weights))

The trainable weights are what I expect, but the gradient comes back as None. Changing the constant tensor to a tf.Variable doesn't help.

If I do something similar directly with tf.GradientTape(), I get the correct gradient when the variable is multiplied by the vector inside the tape context, but no gradient when the multiplication happens before taping. So it seems that, inside the layer, the tape is not yet recording when the variable is multiplied by the vector. How can I fix this?
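For reference, the tape-only experiment described above might look roughly like the following. This is only a sketch: the names v, basis, w_outside and w_inside, and the fixed input values, are made up for illustration and are not part of the layer code.

```python
import tensorflow as tf

v = tf.Variable([2.0])                 # stands in for the layer's weight
basis = tf.constant([1., 0., 1.])      # constant basis vector
x = tf.constant([1., 2., 3.])          # fixed input, for reproducibility

# Case 1: the multiplication happens before the tape starts recording.
w_outside = v * basis
with tf.GradientTape() as tape:
    y = tf.tensordot(x, w_outside, axes=1)
grad_outside = tape.gradient(y, v)     # no recorded path from y back to v

# Case 2: the multiplication happens inside the tape context.
with tf.GradientTape() as tape:
    w_inside = v * basis
    y = tf.tensordot(x, w_inside, axes=1)
grad_inside = tape.gradient(y, v)      # dy/dv = x[0] + x[2]

print(grad_outside)
print(grad_inside)
```

In the first case the tape only sees a plain tensor, so it has nothing connecting y to v. In the second case the product is recorded and the gradient is defined.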

There is 1 solution below

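I ran into the same problem, and moving the multiplication from build() to call() fixed it for me. In your case, that means moving the line self.effective_weight = self.variable*self.basis_vector into call(), so that the product is recomputed from the variable on every forward pass, where the tape can record it.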
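A minimal sketch of what that change could look like for the ToyLayer from the question. The initializer="ones" argument and the fixed input x are my own additions to make the result reproducible; they are not in the original code.

```python
import tensorflow as tf


class ToyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # Fixed (non-trainable) basis vector.
        self.basis_vector = tf.constant([1., 0., 1.])

    def build(self, input_shape):
        # Only create the weight here; do not derive tensors from it yet.
        # initializer="ones" is an assumption made for reproducibility.
        self.variable = self.add_weight(shape=(1,), initializer="ones")

    def call(self, inputs, **kwargs):
        # Recompute the effective weight on every call so the
        # multiplication is recorded by any active GradientTape.
        effective_weight = self.variable * self.basis_vector
        return tf.tensordot(inputs, effective_weight, axes=1)


layer = ToyLayer()
x = tf.constant([1., 2., 3.])
with tf.GradientTape() as tape:
    y = layer(x)
grads = tape.gradient(y, layer.trainable_weights)
print(grads)  # one gradient of shape (1,), equal to x[0] + x[2]
```

Recomputing the product in call() costs one extra elementwise multiplication per forward pass, but it also means the effective weight always reflects the current value of the variable after each optimizer step.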

Relevant links:

  1. https://www.tensorflow.org/guide/autodiff#1_replaced_a_variable_with_a_tensor
  2. Tensorflow, tf.reshape causes "Gradients do not exist for variables"

(I know it's a bit late, but hope it helps others!)