I am re-implementing the Invertible Residual Networks architecture.
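```python
import torch
import torch.nn as nn

class iResNetBlock(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # LinearContraction (defined elsewhere): linear layer with Lipschitz constant < 1
        self.bottleneck = nn.Sequential(
            LinearContraction(input_size, hidden_size),
            LinearContraction(hidden_size, input_size),
            nn.ReLU(),
        )

    def forward(self, x):
        return x + self.bottleneck(x)

    def inverse(self, y, tol=1e-6, max_iter=100):  # tol/max_iter: illustrative
        x = y.clone()
        for _ in range(max_iter):
            # fixed point iteration x <- y - g(x)
            x_new = y - self.bottleneck(x)
            converged = torch.norm(x_new - x) < tol
            x = x_new
            if converged:
                break
        return x
```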
I want to add a custom backward pass to the `inverse` function. Since `inverse` is a fixed-point iteration, the implicit function theorem can be used to avoid unrolling the loop: the gradient is instead obtained by solving a linear system. This is done, for example, in the Deep Equilibrium Models architecture.
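For this block, writing $g_\theta$ for `self.bottleneck` with parameters $\theta$ and $v = \partial L / \partial x^*$ for the gradient arriving at the output of `inverse`, the implicit-function-theorem argument works out roughly as follows (a sketch of the derivation, assuming $g_\theta$ is differentiable at the fixed point $x^*$):

$$
x^* + g_\theta(x^*) = y
\;\Rightarrow\;
\bigl(I + J_g(x^*)\bigr)\,dx^* = dy - \frac{\partial g_\theta}{\partial \theta}\,d\theta,
\qquad J_g(x^*) = \frac{\partial g_\theta}{\partial x}\Big|_{x = x^*}
$$

$$
\frac{\partial L}{\partial y} = u,
\qquad
\frac{\partial L}{\partial \theta} = -\Bigl(\frac{\partial g_\theta}{\partial \theta}\Bigr)^{\!\top} u,
\qquad \text{where } \bigl(I + J_g(x^*)\bigr)^{\top} u = v
\;\Longleftrightarrow\; u = v - J_g(x^*)^{\top} u .
$$

Because $\operatorname{Lip}(g_\theta) < 1$ implies $\|J_g(x^*)^{\top}\|_2 < 1$, the last equation is itself a contraction and can be solved by fixed-point iteration using only vector-Jacobian products, so $J_g$ never has to be materialized.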
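```python
    def inverse(self, y, tol=1e-6, max_iter=100):
        # no graph is recorded, so without a custom backward the gradient w.r.t. y is lost
        with torch.no_grad():
            x = y.clone()
            for _ in range(max_iter):
                x_new = y - self.bottleneck(x)  # fixed point iteration
                converged = torch.norm(x_new - x) < tol
                x = x_new
                if converged:
                    break
        return x

    def custom_backward_inverse(self, grad_output):
        # should return dL/dy computed via the implicit function theorem
        pass
```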
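One way the body of `custom_backward_inverse` could look, as a sketch: here it is a free function (a hypothetical signature that takes the bottleneck `g` and the solution `x_star` explicitly instead of `self`) that solves the adjoint equation u = v − J_g(x*)ᵀu by fixed-point iteration, computing each J_g(x*)ᵀu with `torch.autograd.grad`. `tol` and `max_iter` are illustrative settings, and only ∂L/∂y is returned.

```python
import torch


def custom_backward_inverse(g, x_star, grad_output, tol=1e-8, max_iter=200):
    """Sketch: return u = dL/dy for x* = y - g(x*), given grad_output = dL/dx*.

    u solves (I + J_g(x*))^T u = grad_output, computed here by the fixed-point
    iteration u <- grad_output - J_g(x*)^T u (converges when Lip(g) < 1).
    tol and max_iter are illustrative settings.
    """
    x = x_star.detach().requires_grad_(True)
    with torch.enable_grad():
        gx = g(x)  # small graph used only for vector-Jacobian products
    u = grad_output.clone()
    for _ in range(max_iter):
        # J_g(x*)^T u via reverse mode; retain_graph so gx can be reused
        (vjp,) = torch.autograd.grad(gx, x, u, retain_graph=True)
        u_new = grad_output - vjp
        done = torch.norm(u_new - u) < tol
        u = u_new
        if done:
            break
    return u
```

On a small problem the result can be checked against a dense solve, `torch.linalg.solve((torch.eye(n) + J).T, v)` with `J = torch.autograd.functional.jacobian(g, x_star)`. Parameter gradients would additionally need `torch.autograd.grad(gx, params, -u)` on the same graph, which this sketch leaves out.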
How do I register my custom backward pass for this function? When I later define a loss such as `r = loss(y, model.inverse(other_model(model(x))))`, I want `r.backward()` to use my custom gradient for the `inverse` call.
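A sketch of one possible registration mechanism, the hook-based pattern used in DEQ-style implementations (not verified against TorchScript). The fixed point is found under `torch.no_grad()`; one extra differentiable step `x = y - g(x*)` reconnects `x` to `y` and the parameters; and `Tensor.register_hook` on that `x` replaces the incoming gradient v by the solution u of u = v − J_g(x*)ᵀu. Autograd then delivers ∂L/∂y = u and ∂L/∂θ = −(∂g/∂θ)ᵀu without unrolling the loop. `LinearContraction` below is a hypothetical stand-in (a linear layer rescaled to spectral norm `coeff`), and `tol`, `max_iter` and `_fixed_point` are illustrative additions.

```python
import torch
import torch.nn as nn


class LinearContraction(nn.Module):
    """Hypothetical stand-in: linear layer rescaled so its spectral norm equals coeff < 1."""

    def __init__(self, in_features, out_features, coeff=0.9):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.coeff = coeff

    def forward(self, x):
        w = self.linear.weight
        w = w * (self.coeff / torch.linalg.matrix_norm(w, ord=2))
        return nn.functional.linear(x, w, self.linear.bias)


class iResNetBlock(nn.Module):
    def __init__(self, input_size, hidden_size, tol=1e-6, max_iter=100):
        super().__init__()
        self.bottleneck = nn.Sequential(
            LinearContraction(input_size, hidden_size),
            LinearContraction(hidden_size, input_size),
            nn.ReLU(),
        )
        self.tol, self.max_iter = tol, max_iter  # illustrative solver settings

    def forward(self, x):
        return x + self.bottleneck(x)

    def _fixed_point(self, f, z0):
        # generic fixed point iteration z <- f(z), stopped at tol or max_iter
        z = z0
        for _ in range(self.max_iter):
            z_new = f(z)
            done = torch.norm(z_new - z) < self.tol
            z = z_new
            if done:
                break
        return z

    def inverse(self, y):
        # 1) forward solve of x = y - g(x) without building a graph
        with torch.no_grad():
            x_star = self._fixed_point(lambda x: y - self.bottleneck(x), y.clone())
        # 2) one differentiable step: gradients reach y and the parameters from here
        x = y - self.bottleneck(x_star)
        if x.requires_grad:
            # 3) separate graph used only for vector-Jacobian products with J_g(x*)
            x0 = x_star.clone().requires_grad_()
            g0 = self.bottleneck(x0)

            def backward_hook(grad):
                # replace dL/dx by u solving u = grad - J_g(x*)^T u
                return self._fixed_point(
                    lambda u: grad - torch.autograd.grad(g0, x0, u, retain_graph=True)[0],
                    grad,
                )

            x.register_hook(backward_hook)
        return x
```

With this design nothing has to be registered globally: any expression such as `loss(y, model.inverse(other_model(model(x))))` triggers the hook when `backward()` runs. Whether a `register_hook` call with a Python closure can be compiled by TorchScript has not been checked here.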
Ideally, the solution should be TorchScript-compatible.