stop_gradient in tensorflow


I am wondering whether tf.stop_gradient stops the gradient computation of just a given op, or whether it stops the update of its input tf.Variable. I have the following problem: during the forward pass in MNIST, I would like to perform a set of operations on the weights (say W to W*) and then do a matmul with the inputs. However, I would like to exclude these operations from the backward pass; I want only dE/dW computed during training with backpropagation. The code I wrote prevents W from getting updated. Could you please help me understand why? If these were variables, I understand I should set their trainable property to false, but these are operations on the weights. If stop_gradient cannot be used for this purpose, how do I build two graphs, one for the forward pass and the other for backpropagation?

import math
import tensorflow as tf

def build_layer(inputs, fmap, nscope, layer_size1, layer_size2, faulty_training):
  with tf.name_scope(nscope): 
    if (faulty_training):
      ## trainable weight
      weights_i = tf.Variable(tf.truncated_normal([layer_size1, layer_size2],stddev=1.0 / math.sqrt(float(layer_size1))),name='weights_i')
      ## Operations on weight whose gradient should not be computed during backpropagation
      weights_fx_t = tf.multiply(268435456.0,weights_i)
      weight_fx_t = tf.stop_gradient(weights_fx_t)
      weights_fx = tf.cast(weights_fx_t,tf.int32)
      weight_fx = tf.stop_gradient(weights_fx)
      weights_fx_fault = tf.bitwise.bitwise_xor(weights_fx,fmap)
      weight_fx_fault = tf.stop_gradient(weights_fx_fault)
      weights_fl = tf.cast(weights_fx_fault, tf.float32)
      weight_fl = tf.stop_gradient(weights_fl)
      weights = tf.stop_gradient(tf.multiply((1.0/268435456.0),weights_fl))
      ##### end transformation
    else:
      weights = tf.Variable(tf.truncated_normal([layer_size1, layer_size2],stddev=1.0 / math.sqrt(float(layer_size1))),name='weights')


    biases = tf.Variable(tf.zeros([layer_size2]), name='biases')
    hidden = tf.nn.relu(tf.matmul(inputs, weights) + biases)
    return weights,hidden

I am using the TensorFlow gradient descent optimizer to do the training.

optimizer = tf.train.GradientDescentOptimizer(learning_rate) 
global_step = tf.Variable(0, name='global_step', trainable=False) 
train_op = optimizer.minimize(loss, global_step=global_step)

There is 1 answer below.


tf.stop_gradient will prevent backpropagation from continuing past that node in the graph. Your code doesn't have any path from weights_i to the loss except the one that goes through weights_fx_t, where the gradient is stopped. This is what is causing weights_i not to be updated during training. You don't need to put stop_gradient after every step; using it just once will stop the backpropagation there.
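You can see the effect in a tiny graph (a minimal sketch with illustrative names, not your actual code): when the only path from the loss back to a variable goes through a stop_gradient node, TensorFlow reports no gradient for that variable at all.

import tensorflow as tf

w = tf.Variable(2.0, name='w')
w_scaled = tf.stop_gradient(3.0 * w)  # backprop is cut at this node
loss = w_scaled * w_scaled

# The only path from loss back to w goes through the stop_gradient node,
# so there is no gradient for w.
print(tf.gradients(loss, w))  # prints [None]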

If stop_gradient doesn't do what you want, you can compute the gradients yourself with tf.gradients and write your own update op using tf.assign. This lets you alter the gradients however you want.
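Here is a rough, self-contained sketch of that idea in the TF 1.x API. The toy transform (multiplying by 2.0), the data, and the learning_rate are illustrative stand-ins for your fault-injection code, not your actual setup: the gradient dE/dW* is taken with respect to the transformed tensor used in the forward pass and then applied directly to the underlying variable W.

import tensorflow as tf

# Toy setup standing in for the question's graph: weights_i is the trainable
# variable W, weights_star is a transformed copy W* used only in the forward pass.
inputs = tf.constant([[1.0, 2.0]])
targets = tf.constant([[1.0]])
weights_i = tf.Variable(tf.truncated_normal([2, 1]), name='weights_i')
weights_star = tf.stop_gradient(2.0 * weights_i)  # forward-pass transform W -> W*
loss = tf.reduce_mean(tf.square(tf.matmul(inputs, weights_star) - targets))

learning_rate = 0.01  # assumed hyperparameter
# dE/dW*: stop_gradient only blocks flow past this tensor, so this is well defined.
grad_w_star = tf.gradients(loss, weights_star)[0]
# Manual SGD step: apply dE/dW* to the raw variable W, ignoring the transform
# in the backward pass.
train_op = tf.assign_sub(weights_i, learning_rate * grad_w_star)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)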