PyTorch: Add Custom Backward Pass for nn.Module Function


I am re-implementing the Invertible Residual Networks architecture.

class iResNetBlock(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.bottleneck = nn.Sequential(
            LinearContraction(input_size, hidden_size),
            LinearContraction(hidden_size, input_size),
            nn.ReLU(),
        )

    def forward(self, x):
        return x + self.bottleneck(x)

    def inverse(self, y, atol=1e-6, max_iter=100):
        x = y.clone()
        # fixed-point iteration: x = y - bottleneck(x) converges
        # because the bottleneck is a contraction
        for _ in range(max_iter):
            x_next = y - self.bottleneck(x)
            if torch.norm(x_next - x) < atol:
                return x_next
            x = x_next
        return x

I want to add a custom backward pass to the inverse function. Since it is a fixed-point iteration, one can use the implicit function theorem to avoid unrolling the loop and instead compute the gradient by solving a linear system. This is done, for example, in the Deep Equilibrium Models architecture.
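To make the linear system concrete: since the inverse x solves x = y - f(x) (writing f for the bottleneck), differentiating gives (I + J_f(x)) dx = dy, so the backward pass only needs v = (I + J_f(x))^{-T} grad_output, which can itself be obtained by iterating v ← grad_output - J_f(x)^T v. A minimal sketch, using a stand-in contraction f rather than the actual LinearContraction layers:

```python
import torch

def f(x):
    # stand-in contraction with Lipschitz constant 0.5
    return 0.5 * torch.tanh(x)

def inverse_fixed_point(y, n_iter=60):
    # fixed-point iteration for x = y - f(x)
    x = y.clone()
    for _ in range(n_iter):
        x = y - f(x)
    return x

def inverse_grad(y, grad_output, n_iter=60):
    # implicit gradient of the inverse w.r.t. y:
    # solve (I + J_f(x))^T v = grad_output by iterating
    # v <- grad_output - J_f(x)^T v, using autograd for the vjp
    x = inverse_fixed_point(y, n_iter).detach().requires_grad_(True)
    fx = f(x)
    v = grad_output.clone()
    for _ in range(n_iter):
        (vjp,) = torch.autograd.grad(fx, x, v, retain_graph=True)
        v = grad_output - vjp
    return v
```

For a contraction, the vjp iteration converges geometrically, and its result agrees with what autograd would compute by unrolling the forward loop, without storing the loop's intermediate activations.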

    def inverse(self, y, atol=1e-6, max_iter=100):
        with torch.no_grad():
            x = y.clone()
            # fixed-point iteration
            for _ in range(max_iter):
                x_next = y - self.bottleneck(x)
                if torch.norm(x_next - x) < atol:
                    break
                x = x_next
            return x

    def custom_backward_inverse(self, grad_output):
        pass

How do I register my custom backward pass for this function? I want that, when I later define a loss such as r = loss(y, model.inverse(other_model(model(x)))), calling r.backward() correctly uses my custom gradient for the inverse call.

Ideally the solution should be torchscript-compatible.
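For reference, the generic mechanism I am aware of is a torch.autograd.Function subclass, with inverse dispatching to its apply. A hedged sketch of what I have in mind (gradients w.r.t. the bottleneck's own parameters are omitted for brevity, and my real code would need them):

```python
import torch

class InverseFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y, bottleneck, n_iter):
        with torch.no_grad():
            x = y.clone()
            for _ in range(n_iter):
                # fixed-point iteration for x = y - bottleneck(x)
                x = y - bottleneck(x)
        ctx.save_for_backward(x)
        ctx.bottleneck = bottleneck
        ctx.n_iter = n_iter
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            fx = ctx.bottleneck(x)
        # solve (I + J)^T v = grad_output by iterating
        # v <- grad_output - J^T v (converges for a contraction)
        v = grad_output.clone()
        for _ in range(ctx.n_iter):
            (vjp,) = torch.autograd.grad(fx, x, v, retain_graph=True)
            v = grad_output - vjp
        # one gradient per forward input: y, bottleneck, n_iter
        return v, None, None
```

With this, r.backward() dispatches to the custom backward whenever the inverse appears in the graph. My concern is that, as far as I know, TorchScript's support for custom autograd.Function subclasses is limited, hence the question about a torchscript-compatible way.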
