I'm looking for a way to write the equivalent of the following PyTorch module in Flax, but I haven't found one. The important part is that the constant must be saved to and loaded from checkpoints.
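```python
import torch
import torch.nn as nn

class SillyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer is stored in state_dict() (so it is checkpointed) but is not a trainable parameter
        self.register_buffer('constant', torch.randn(1, 128))
    def forward(self, x):
        return torch.matmul(x, self.constant)
```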
Does anybody know how to do this? What is the Flax equivalent of `register_buffer`?