Here is the source code for `FrozenBatchNorm2d`:
import torch
from torch import nn
class FrozenBatchNorm2d(nn.Module):
    """Batch normalization whose statistics and affine terms are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers rather than parameters, so they are stored in the module's
    state dict but are not returned by ``parameters()`` and therefore are
    not updated by an optimizer.

    Args:
        n: Size passed to ``torch.ones`` / ``torch.zeros`` when creating the
            buffers; intended to be the number of channels (an int).
        eps: Value added to ``running_var`` before the reciprocal square
            root. The default of ``0.0`` keeps the historical behaviour;
            pass the ``eps`` of the batch-norm layer being frozen (for
            example ``1e-5``) to reproduce its output and to keep the
            result finite when a variance entry is zero.
    """

    def __init__(self, n, eps=0.0):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        """Apply the fixed per-channel affine transform to ``x``.

        Computes ``(x - running_mean) / sqrt(running_var + eps) * weight + bias``
        folded into a single multiply-add.

        Args:
            x: Input tensor. The (1, -1, 1, 1) reshape below lines the
                buffers up with dimension 1 of a 4-D input, i.e. an
                (N, C, H, W) layout is assumed.

        Returns:
            ``x * scale + bias`` with ``scale`` and ``bias`` broadcast
            against ``x``.
        """
        # Fold normalization and affine terms into one scale and one offset.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias = self.bias - self.running_mean * scale
        # Reshape so the per-channel values broadcast along dimension 1.
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return x * scale + bias
When I put this class in my script, I found that it had almost no effect. Here is how I used it:
import torch.nn as nn
import torch
class FrozenBatchNorm2d(nn.Module):
    """
    A BatchNorm2d variant in which both the batch statistics and the
    affine terms are held fixed: all four tensors are stored as buffers,
    not as learnable parameters.
    """

    def __init__(self, n):
        super().__init__()
        # Registration order is kept: weight, bias, running_mean, running_var.
        buffer_specs = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, make_tensor in buffer_specs:
            self.register_buffer(buffer_name, make_tensor(n))

    def forward(self, x):
        """Return ``x * scale + bias`` using the fixed buffers.

        ``scale`` is ``weight / sqrt(running_var)`` and ``bias`` is
        ``bias - running_mean * scale``; both are reshaped to
        (1, -1, 1, 1) before being applied, and their shapes are printed.
        """
        broadcast_shape = (1, -1, 1, 1)
        inv_std = self.running_var.rsqrt()
        gain = self.weight * inv_std
        shift = self.bias - self.running_mean * gain
        gain = gain.reshape(broadcast_shape)
        shift = shift.reshape(broadcast_shape)
        # Debug output from the original script: shapes after reshaping.
        print(gain.shape, shift.shape)
        scaled = x * gain
        return scaled + shift
# The size argument is a tuple here, so each buffer is created with shape (1, 2).
frozen_bn = FrozenBatchNorm2d((1, 2))
# 1-D integer tensor of length 3; the result of the call is not assigned.
sample = torch.tensor([1, 2, 3])
frozen_bn(sample)
The output is different from what I expected. Can someone tell me what exactly this function does? I would appreciate any help.
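`register_buffer` stores a tensor as part of the module's state without making it a learnable parameter, so it cannot be optimized or changed during the training process. In other words, "weight", "bias", "running_mean" and "running_var" are constant values, and that is why this rebuilt batch-norm layer is called FrozenBatchNorm2d. Note that with the default buffers (weight = 1, bias = 0, running_mean = 0, running_var = 1) the computed scale is 1 and the offset is 0, so the layer returns its input unchanged until real statistics are loaded into it. That is my explanation; I hope it helps.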