What is the argument "grad_outputs" in Chainer's backward function?

3 questions:

  1. What is grad_outputs in Chainer?

  2. As an example, take Chainer's F.transpose function. How should this backward code be read?

    def backward(self, inputs, grad_outputs):
        gy = grad_outputs[0]
        inv_axes = self.axes
        if self.axes:
            axes = tuple(ax % len(self.axes) for ax in self.axes)
            inv_axes = tuple(numpy.argsort(axes))
        gx = gy.transpose(inv_axes)
        return gx,

  3. Suppose I want to implement my own function, but inputs[0] and inputs[1] have different shapes. To backpropagate using the chain rule, I would have to write the following code in backward:

    a, b = inputs
    gy = grad_outputs[0]
    return a * gy, b * gy

But a and b do not have the same shape, so might a * gy and b * gy raise an error because the shapes do not match for multiplication?

There is 1 best solution below

*This answer applies to Chainer v2. The Function class's internal behavior may change from Chainer v3 onward to support differentiable backpropagation.

Backpropagation proceeds from the final layer back to the first layer, propagating gradients so that the gradient for each layer's parameters can be calculated.
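
To make the direction of the computation concrete, here is a minimal sketch in plain Python. The two "layers" (y = 3 * x and z = y ** 2) are made up for illustration and are not Chainer functions. The gradient starts at the final output and is multiplied by each layer's local derivative on the way back.

```python
# Forward pass through two hypothetical layers: y = 3 * x, then z = y ** 2.
x = 2.0
y = 3.0 * x
z = y ** 2

# Backward pass, starting from the final output and moving toward the input.
gz = 1.0             # dz/dz
gy = gz * 2.0 * y    # dz/dy = 2 * y
gx = gy * 3.0        # dz/dx = dz/dy * dy/dx

print(gx)  # 36.0
```

This agrees with differentiating z = 9 * x ** 2 directly, which gives 18 * x = 36 at x = 2.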

A function's backward method receives the gradient of its output and needs to calculate and return the gradient of its input.
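
As a rough sketch of that contract, the following uses NumPy only and mimics the (inputs, grad_outputs) signature without subclassing chainer.Function. The forward computation (a matrix product) is an assumption chosen for illustration because its two inputs have different shapes. Each returned gradient has the shape of its own input.

```python
import numpy as np


def forward(inputs):
    # Hypothetical forward pass: a plain matrix product y = a @ b.
    a, b = inputs
    return (a @ b,)


def backward(inputs, grad_outputs):
    # grad_outputs[0] is dL/dy and has the same shape as the output y.
    a, b = inputs
    gy = grad_outputs[0]
    ga = gy @ b.T  # dL/da, same shape as a
    gb = a.T @ gy  # dL/db, same shape as b
    return ga, gb


a = np.ones((2, 3), dtype=np.float32)
b = np.ones((3, 4), dtype=np.float32)
(y,) = forward((a, b))
gy = np.ones_like(y)  # gradient arriving from the next layer
ga, gb = backward((a, b), (gy,))
print(ga.shape, gb.shape)  # (2, 3) (3, 4)
```

The exact expressions inside backward depend on what forward computes, so a product like a * gy is only appropriate when the forward pass makes it the correct derivative.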

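  1. grad_outputs is a tuple holding the gradient with respect to each of this function's outputs, in array (numpy or cupy) form. grad_outputs[0] is the gradient for the first output.
  2. I believe the basic idea is that the derivative of F.transpose is also just a transpose, so backward returns a transposed version of the output gradient, gy. More rigorously, the axis order is specified when the forward computation runs and is kept as self.axes, and the backward computation has to undo that permutation. inv_axes is the inverse permutation of self.axes, not simply the axes in reverse order: ax % len(self.axes) maps negative axes to non-negative ones, and numpy.argsort of the result gives the permutation that restores the original order. It is used to calculate the gradient of the input, written as gx. When self.axes is None, the forward pass reverses all axes, and gy.transpose(None) reverses them back.
  3. As you wrote, you return the gradients of the inputs as a tuple, like return a * gy, b * gy. The returned gradients do not need to share a shape with one another: each one should have the same shape as its corresponding input, while gy has the shape of the output. Whether an expression such as a * gy is shape-compatible depends on what forward computes.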
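
For the second point, the inverse permutation can be checked with NumPy alone. This is a small sketch, and the axes value below is a hypothetical stand-in for self.axes. It applies the same normalization and argsort steps as the backward code in the question and confirms that the result undoes the forward transpose.

```python
import numpy as np

axes = (-1, 0, 1)  # hypothetical self.axes, including a negative axis

# Same two steps as in backward: normalize negative axes, then invert.
norm_axes = tuple(ax % len(axes) for ax in axes)  # (2, 0, 1)
inv_axes = tuple(np.argsort(norm_axes))           # (1, 2, 0)

x = np.arange(24).reshape(2, 3, 4)
y = x.transpose(axes)         # forward output, shape (4, 2, 3)
gy = np.ones_like(y)          # gradient with respect to y
gx = gy.transpose(inv_axes)   # gradient with respect to x, shape (2, 3, 4)

# Transposing with inv_axes undoes the forward transpose exactly.
restored = y.transpose(inv_axes)
print(gx.shape, np.array_equal(restored, x))
```

Merely reversing the tuple would give (1, 0, 2) here, which produces shape (2, 4, 3) rather than the input's (2, 3, 4), so argsort is what makes gx line up with x.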