How to have batch norm not forget the batch statistics it just used in PyTorch?

I am in an unusual setting where I should not use running statistics (that would be considered cheating, e.g. in meta-learning). However, I often run a forward pass on a set of points (5, in fact) and then want to evaluate on only 1 point using the statistics of that previous batch, but batch norm forgets the batch statistics it just used. I've tried to hard-code the values it should use, but I get strange errors (even when I comment out things in the PyTorch code itself, like the check on the number of input dimensions).

How do I hard-code the previous batch statistics so that batch norm works on a single new data point, and then reset them for the next batch?

Note: I don't want to change the batch norm layer type.

Sample code I tried:

def set_tracking_running_stats(model):
    for attr in dir(model):
        if 'bn' in attr:
            target_attr = getattr(model, attr)
            target_attr.track_running_stats = True
            target_attr.running_mean = torch.nn.Parameter(torch.zeros(target_attr.num_features, requires_grad=False))
            target_attr.running_var = torch.nn.Parameter(torch.ones(target_attr.num_features, requires_grad=False))
            target_attr.num_batches_tracked = torch.nn.Parameter(torch.tensor(0, dtype=torch.long), requires_grad=False)
            # target_attr.reset_running_stats()

My most common errors:

    raise ValueError('expected 2D or 3D input (got {}D input)'
ValueError: expected 2D or 3D input (got 1D input)


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)



There are 2 best solutions below


I think this is the solution: call mdl.train(), which makes the layer normalize with the statistics of the current batch. From the docs:

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.

If track_running_stats is set to False, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well.
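As a quick check of the two behaviours quoted above, here is a minimal sketch. The tensor shapes, the seed, and `affine=False` (used only so the output can be compared against a hand-computed normalization) are assumptions for illustration, not part of the question.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 3) * 4.0 + 2.0  # 5 points, 3 features (illustrative shapes)

# Hand-computed normalization with the batch's own (biased) statistics.
eps = 1e-5  # default eps of BatchNorm1d
expected = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + eps)

# Default layer: in train() mode the current batch's statistics are used.
bn = torch.nn.BatchNorm1d(3, affine=False)
bn.train()
train_uses_batch_stats = torch.allclose(bn(x), expected, atol=1e-5)

# track_running_stats=False: no running buffers exist, so eval() mode
# also normalizes with the current batch's statistics.
bn_no_track = torch.nn.BatchNorm1d(3, affine=False, track_running_stats=False)
bn_no_track.eval()
eval_uses_batch_stats = torch.allclose(bn_no_track(x), expected, atol=1e-5)
no_running_buffers = bn_no_track.running_mean is None

print(train_uses_batch_stats, eval_uses_batch_stats, no_running_buffers)
```

Note that both variants normalize with whatever batch is passed in, so on their own they do not carry the 5-point statistics over to a later single point.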



To restore the initial state of the mean and variance in batch norm (zeros and ones, respectively), use batch_norm.reset_running_stats().
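A small sketch of how that reset could fit the 5-points-then-1-point workflow from the question. Setting `momentum=1.0` is my assumption here, not something stated above: with PyTorch's update rule it makes the running estimates equal to the most recent batch's statistics. The names `support` and `query` are hypothetical.

```python
import torch

torch.manual_seed(0)
# Assumption: momentum=1.0 so the running estimates hold only the last batch.
bn = torch.nn.BatchNorm1d(3, momentum=1.0, affine=False)

support = torch.randn(5, 3) * 4.0 + 2.0  # the 5-point batch
query = torch.randn(1, 3)                # one point, kept 2D: (1, features)

bn.train()
bn(support)  # forward pass stores this batch's statistics in the buffers
stored_mean_matches = torch.allclose(
    bn.running_mean, support.mean(dim=0), atol=1e-5
)

bn.eval()
out = bn(query)  # the single point is normalized with the stored statistics
manual = (query - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
query_uses_stored_stats = torch.allclose(out, manual, atol=1e-5)

bn.reset_running_stats()  # back to zeros / ones before the next batch
reset_ok = (
    torch.equal(bn.running_mean, torch.zeros(3))
    and torch.equal(bn.running_var, torch.ones(3))
    and int(bn.num_batches_tracked) == 0
)
```

One caveat with this sketch: the stored `running_var` is the unbiased estimate, while train mode normalizes with the biased one, so for a batch of 5 the two differ slightly. Keeping the single point as a `(1, features)` tensor rather than a 1D tensor also avoids the "expected 2D or 3D input" error.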

You can run this function on your model at any given time (in your case, when a new batch of 5 training examples arrives) and it should do the trick:

def reset_all_running_stats(model):
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):