How to have batch norm not forget the batch statistics it just used in PyTorch?


I am in an unusual setting where I should not use running statistics, since that would be considered cheating (e.g. in meta-learning). However, I often run a forward pass on a set of points (5, in fact) and then want to evaluate on only 1 point using the statistics of that batch, but batch norm forgets the batch statistics it just used. I've tried to hard-code the values it should have, but I get strange errors, even when I tinker with the checks in the PyTorch code itself, like the dimension-size check.

How do I hard-code the previous batch statistics so that batch norm works on a new single data point, and then reset them for a fresh new batch?

note: I don't want to change the batch norm layer type.
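
A minimal sketch of what I mean by "forgets" (the layer size and input shapes are arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)  # defaults: momentum=0.1, track_running_stats=True

batch = torch.randn(5, 3)  # the set of 5 points
point = torch.randn(1, 3)  # the single evaluation point

bn.train()
_ = bn(batch)  # normalized with this batch's own mean/var

bn.eval()
_ = bn(point)  # normalized with the running stats, i.e.
               # 0.9 * (initial zeros/ones) + 0.1 * batch stats,
               # not the batch statistics themselves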

Sample code I tried:

import torch

def set_tracking_running_stats(model):
    # Re-register fresh running statistics on every attribute whose name
    # contains 'bn' (assumed to be the batch norm layers).
    for attr in dir(model):
        if 'bn' in attr:
            target_attr = getattr(model, attr)
            target_attr.track_running_stats = True
            target_attr.running_mean = torch.nn.Parameter(torch.zeros(target_attr.num_features), requires_grad=False)
            target_attr.running_var = torch.nn.Parameter(torch.ones(target_attr.num_features), requires_grad=False)
            target_attr.num_batches_tracked = torch.nn.Parameter(torch.tensor(0, dtype=torch.long), requires_grad=False)
            # target_attr.reset_running_stats()
    return

My most common errors:

    raise ValueError('expected 2D or 3D input (got {}D input)'
ValueError: expected 2D or 3D input (got 1D input)

and

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)


2 Answers


I think this is the solution: use mdl.train(), since in training mode batch norm uses the statistics of the current batch by itself. From the PyTorch docs:

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.

If track_running_stats is set to False, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well.

https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html

ref: https://discuss.pytorch.org/t/how-to-use-have-batch-norm-not-forget-batch-statistics-it-just-used/103437/4
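
A minimal sketch of the two behaviors the quoted docs describe (the layer size and input shape are arbitrary):

import torch
import torch.nn as nn

x = torch.randn(5, 3)

# Option 1: keep the layer in train mode; batch norm then normalizes
# with the statistics of whatever batch it is given.
bn = nn.BatchNorm1d(3)
bn.train()
out_train = bn(x)

# Option 2: disable running stats entirely; batch statistics are then
# used at evaluation time as well.
bn_no_stats = nn.BatchNorm1d(3, track_running_stats=False)
bn_no_stats.eval()
out_eval = bn_no_stats(x)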


In order to reset the mean and variance in batch norm to their initial state (zeros and ones, respectively), you should use batch_norm.reset_running_stats().

You can run this function on your model at any time (in your case, when a new batch with 5 training examples arrives), and it should do the trick:

import torch

def reset_all_running_stats(model):
    # Restore the running mean/var (and batch counter) of every
    # BatchNorm2d layer in the model to their initial values.
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.reset_running_stats()
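
Putting this together with the first answer, a sketch of the full cycle; the momentum=1.0 setting here is my own assumption so that the running stats become exactly the last batch's statistics (neither answer requires it):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8, momentum=1.0))

support = torch.randn(5, 3, 16, 16)  # the 5 points
query = torch.randn(1, 3, 16, 16)    # the single evaluation point

model.train()
_ = model(support)  # with momentum=1.0, the running stats are overwritten
                    # with this batch's mean/var

model.eval()
out = model(query)  # the single point is normalized with those stats

reset_all_running_stats(model)  # function from above: back to zeros/ones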