What is the function of FrozenBatchNorm2d in “maskrcnn_benchmark”?


maskrcnn_benchmark's GitHub repository

Here is the source code for FrozenBatchNorm2d:

import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # Per-channel affine parameters and statistics, registered as buffers
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        # Fold normalization and affine transform into one scale and shift per channel
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        # Reshape to (1, C, 1, 1) so they broadcast over an (N, C, H, W) input
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return x * scale + bias
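The forward pass folds the four buffers into a single per-channel scale and shift, so it computes y = (x - running_mean) / sqrt(running_var) * weight + bias, which is the inference-mode BatchNorm formula with no epsilon term. A minimal sketch checking this numerically against torch.nn.functional.batch_norm with eps=0.0; the variable names and the random values standing in for pretrained statistics are illustrative assumptions, not taken from maskrcnn_benchmark:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
channels = 3

# Arbitrary per-channel values standing in for pretrained statistics (illustrative only)
weight = torch.rand(channels) + 0.5
bias = torch.randn(channels)
running_mean = torch.randn(channels)
running_var = torch.rand(channels) + 0.5  # must be positive for rsqrt

x = torch.randn(2, channels, 4, 4)  # (N, C, H, W)

# Folded form, as in FrozenBatchNorm2d.forward: one multiply and one add per element
scale = weight * running_var.rsqrt()
shift = bias - running_mean * scale
folded = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)

# Standard inference-mode BatchNorm (training=False) with eps = 0 for comparison
reference = F.batch_norm(
    x, running_mean, running_var, weight=weight, bias=bias, training=False, eps=0.0
)

print(torch.allclose(folded, reference, atol=1e-5))
```

Folding the statistics into one scale and shift ahead of time is what lets the layer cost only a multiply and an add per element at run time.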

When I used this class in my own script, it seemed to have almost no effect. Here is my usage:

import torch.nn as nn
import torch
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        print(scale.shape, bias.shape)
        return x * scale + bias


a = FrozenBatchNorm2d((1, 2))
a(torch.tensor([1, 2, 3]))

The result is different from what I expected. Can someone explain what this class actually does? I would appreciate any help.

Answer

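register_buffer registers a tensor as part of the module's state without making it a trainable parameter. Buffers are saved in state_dict() and moved along with the module by .to() or .cuda(), but they are not returned by parameters(), so the optimizer never updates them. In other words, "weight", "bias", "running_mean" and "running_var" stay constant during training, and forward() does not update the running statistics either (unlike nn.BatchNorm2d in training mode). That is why this rebuilt BatchNorm is called FrozenBatchNorm2d: it applies y = (x - running_mean) / sqrt(running_var) * weight + bias per channel with fixed values, which are normally loaded from a pretrained checkpoint. With the default values (weight = 1, bias = 0, running_mean = 0, running_var = 1) the layer is an identity map, which is why it seemed to have almost no effect in your test. That's my explanation; I hope it helps.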
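A minimal sketch of what this means in practice, assuming a recent PyTorch version. It reuses the class from the question and shows that the four tensors are listed in state_dict() (so pretrained values can be loaded into them) but not in parameters() (so an optimizer never sees them), and that the default buffers make the layer an identity map. The "pretrained" values loaded below are made up for illustration. The layer is meant to be built with an integer channel count and applied to an (N, C, H, W) tensor; the call from the question, FrozenBatchNorm2d((1, 2)) on a 1-D tensor, only broadcasts to a (1, 2, 1, 3) result.

```python
import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d where the batch statistics and the affine parameters are fixed."""

    def __init__(self, n):
        super().__init__()
        # Buffers: stored in state_dict() and moved by .to()/.cuda(),
        # but not returned by parameters(), so no optimizer updates them
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        return x * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)


layer = FrozenBatchNorm2d(4)  # n is the number of channels C

# Nothing for an optimizer to train...
print(list(layer.parameters()))  # []
# ...but the four tensors are part of the saved/loaded state
print(sorted(layer.state_dict().keys()))
# ['bias', 'running_mean', 'running_var', 'weight']

x = torch.randn(2, 4, 8, 8)  # (N, C, H, W)

# Default buffers (weight=1, bias=0, mean=0, var=1) give an identity map
print(torch.allclose(layer(x), x))  # True

# Load hypothetical "pretrained" statistics (made-up values for illustration)
layer.load_state_dict({
    "weight": torch.full((4,), 2.0),
    "bias": torch.zeros(4),
    "running_mean": torch.zeros(4),
    "running_var": torch.ones(4),
})
print(torch.allclose(layer(x), 2 * x))  # True

# The call from the question: a (1, 2)-shaped "n" and a 1-D input just broadcast
odd = FrozenBatchNorm2d((1, 2))
print(odd(torch.tensor([1, 2, 3])).shape)  # torch.Size([1, 2, 1, 3])
```

Once real statistics are loaded, the layer behaves like a BatchNorm2d permanently in eval mode.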