Applying an input-dependent transformation to the gradients during the backward pass (pytorch)


I have a specific use case for autograd and I'd like your opinion on an efficient way to make it work.

For context, I want to apply a transformation to the gradients before executing the optimization step. This transformation multiplies the gradient of the loss with the gradient of the prediction function.

For a parameter p, a prediction function (neural network) f, and a loss L, the transformation is the following:

grad_p = d(f)/d(p) * d(L)/d(p)

The problem is that both gradients are input-dependent, which means that the transformation must be applied before the gradient accumulation step during the backward pass.

The naive solution I can think of is using a batch size of 1; in that case I can proceed as follows:

X, y = data
prediction = model(X)
prediction.sum().backward(retain_graph=True)  # sum() because backward() needs a scalar
gradients = []
for param in model.parameters():
    gradients.append(param.grad.clone())  # clone so zero_grad() can't invalidate them
model.zero_grad()
loss = loss_fn(prediction, y)  # MSE
loss.backward()
for i, param in enumerate(model.parameters()):
    param.grad *= gradients[i]

Now this is very compute-inefficient, so I would like to ask: is there a way to process full batches while applying the transformation on the fly during the backward pass?
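One batched route I have been sketching is computing per-sample gradients functionally and multiplying them before the batch reduction. This is only a minimal sketch, assuming PyTorch ≥ 2.0 with `torch.func`; the `Linear` model, the data, and the helper names `predict` and `sample_loss` are placeholders I made up:

```python
import torch
from torch.func import functional_call, grad, vmap

# Placeholder model and data, just to illustrate the mechanics.
model = torch.nn.Linear(4, 1)
params = {k: v.detach() for k, v in model.named_parameters()}
X, y = torch.randn(8, 4), torch.randn(8, 1)

def predict(p, x):
    # Scalar output for a single sample, so grad() can differentiate it.
    return functional_call(model, p, (x,)).sum()

def sample_loss(p, x, t):
    # Per-sample MSE for a single (x, t) pair.
    return ((functional_call(model, p, (x,)) - t) ** 2).mean()

# Per-sample gradients: vmap maps over the batch dim of X (and y),
# grad differentiates w.r.t. the params dict.
pred_grads = vmap(grad(predict), in_dims=(None, 0))(params, X)
loss_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, X, y)

# Multiply per sample, then reduce over the batch.
final = {k: (pred_grads[k] * loss_grads[k]).sum(dim=0) for k in params}
```

This avoids the batch-size-1 loop entirely, at the cost of materializing per-sample gradients (each entry of `pred_grads` has an extra leading batch dimension).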

I have thought about registering a backward hook but it’s not clear to me how to pass the first gradient to the hook.
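For the hook idea, the one way I can see to pass the first gradient is to capture it in a closure when registering the hook. A minimal sketch (the `Linear` model, loss, and data are placeholders; the mechanism is the `g=g` binding in the lambda):

```python
import torch

# Placeholder model, loss, and data.
model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
X, y = torch.randn(8, 4), torch.randn(8, 1)

# First backward: gradient of the (summed) predictions w.r.t. each parameter.
prediction = model(X)
prediction.sum().backward(retain_graph=True)
pred_grads = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Each hook closes over its stored prediction gradient (g=g binds it at
# definition time) and returns the transformed gradient, which autograd
# then accumulates into p.grad instead of the raw loss gradient.
handles = [p.register_hook(lambda grad, g=g: grad * g)
           for p, g in zip(model.parameters(), pred_grads)]

loss = loss_fn(prediction, y)
loss.backward()

for h in handles:
    h.remove()
```

Note this still multiplies gradients that were already summed over the batch, so it only reproduces the intended math when the batch has a single sample; the closure trick solves the "how do I pass the first gradient to the hook" part, not the per-sample part.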

Thank you.
