Flax/JAX equivalent of PyTorch's `register_buffer`

I'm looking for a way to write the Flax equivalent of the following PyTorch module, but I haven't found one. The important requirement is that the constant must be saved to and restored from checkpoints.

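```python
import torch
import torch.nn as nn

class SillyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-trainable tensor: saved in state_dict(), not returned by parameters()
        self.register_buffer('constant', torch.randn(1, 128))

    def forward(self, x):
        # x has shape (..., 1); the result has shape (..., 128)
        return torch.matmul(x, self.constant)
```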
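For reference, a minimal sketch of the two buffer properties this question relies on, using a bare `nn.Module` for brevity: a tensor registered with `register_buffer` (default `persistent=True`) appears in `state_dict()`, so it is saved and restored with checkpoints, but it is not returned by `parameters()`, so an optimizer never updates it.

```python
import torch
import torch.nn as nn

# A bare module is enough to show how buffers behave.
m = nn.Module()
m.register_buffer('constant', torch.randn(1, 128))

# Buffers are included in the state dict, so torch.save(m.state_dict(), ...) checkpoints them.
state = m.state_dict()
print(list(state.keys()))  # ['constant']

# They are not trainable parameters, so optimizers never see them.
print(len(list(m.parameters())))  # 0

# Loading a checkpoint copies the saved value into the existing buffer.
m2 = nn.Module()
m2.register_buffer('constant', torch.zeros(1, 128))
m2.load_state_dict(state)
print(torch.equal(m2.constant, m.constant))  # True
```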

Does anybody know how to do this? What is the equivalent of `register_buffer` in Flax?