Why use BatchNorm?
Without Batch Norm, when the input features differ greatly in scale, gradient descent proceeds at very different rates along different directions; the updates become unbalanced and training cannot converge stably.
There are four widely known normalization methods. For input data of shape [N, C, H*W] (N is the number of tensors in the batch, C the number of channels, H the height, W the width), they differ in which axes the mean and variance are computed over (see the sketch after this list):
- Batch Norm: computes the mean and variance per channel across the whole batch (all N tensors). For example, for an input of shape [10, 4, 9], the statistics form a 1×4 tensor, one value for each of channels 0, 1, 2, 3.
- Layer Norm: computes the mean and variance over all channels of each individual tensor.
- Instance Norm: computes the mean and variance separately for each channel of each tensor.
- Group Norm: introduces the notion of channel groups, e.g. the B, G, R channels forming one group (less common).
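A minimal sketch of the four layers in PyTorch, using an assumed input of shape [10, 4, 3, 3] (so H*W = 9) purely for illustration; the constructor arguments follow the torch.nn documentation:
```python
import torch
import torch.nn as nn

x = torch.randn(10, 4, 3, 3)  # [N, C, H, W] = [10, 4, 3, 3]

bn = nn.BatchNorm2d(4)        # one mean/var per channel, over (N, H, W)
ln = nn.LayerNorm([4, 3, 3])  # one mean/var per sample, over (C, H, W)
inn = nn.InstanceNorm2d(4)    # one mean/var per (sample, channel), over (H, W)
gn = nn.GroupNorm(2, 4)       # 2 groups of 2 channels each

for name, layer in [("BatchNorm2d", bn), ("LayerNorm", ln),
                    ("InstanceNorm2d", inn), ("GroupNorm", gn)]:
    print(name, layer(x).shape)  # the output shape is always [10, 4, 3, 3]
```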
BatchNorm2d
How BN is computed (the formula below is taken from the `nn.BatchNorm2d` documentation):

.. math::
    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

The mean and variance are computed per channel over the mini-batch; :math:`\gamma` and :math:`\beta` are the learnable ``weight`` and ``bias`` parameters (present when ``affine=True``).
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples::
>>> # With Learnable Parameters
>>> m = nn.BatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm2d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45)
>>> output = m(input)
A concrete example:
```python
import torch
import torch.nn as nn

m = nn.BatchNorm2d(2, affine=True)  # learnable weight (gamma) and bias (beta) are used
input = torch.randn(1, 2, 3, 4)
output = m(input)

print("Input:")
print(input)
print("Normalization weight (gamma):")
print(m.weight)
print("Normalization bias (beta):")
print(m.bias)
print("Normalized output:")
print(output)
print("Output size:")
print(output.size())

# Recompute the first channel by hand. With a batch size of 1, the
# per-channel batch statistics are just the statistics of input[0][0].
print("First channel of the input:")
print(input[0][0])
firstDimenMean = input[0][0].mean()
firstDimenVar = input[0][0].var(unbiased=False)  # Bessel's correction is NOT used
print(m.eps)
print("Mean of the first channel:")
print(firstDimenMean)
print("Variance of the first channel:")
print(firstDimenVar)

batchnorm_one = ((input[0][0][0][0] - firstDimenMean)
                 / torch.pow(firstDimenVar + m.eps, 0.5)) \
                * m.weight[0] + m.bias[0]
print(batchnorm_one)  # matches output[0][0][0][0]
```
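One more point worth noting before the fusion section: BatchNorm behaves differently in train() and eval() mode. A minimal sketch (shapes chosen arbitrarily):
```python
import torch
import torch.nn as nn

m = nn.BatchNorm2d(2)
x = torch.randn(8, 2, 4, 4)

m.train()
_ = m(x)  # normalizes with the batch statistics and updates the running ones
print(m.running_mean, m.running_var)

m.eval()
y = m(x)  # normalizes with running_mean / running_var instead
```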
3. Fusing Convolution and BatchNorm in PyTorch
At inference time BN uses the fixed running statistics :math:`\mu` and :math:`\sigma^2`, so a convolution followed by BN is still an affine map and the two can be folded into a single convolution:

.. math::
    \gamma \cdot \frac{Wx + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta = W'x + b',
    \qquad W' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\, W,
    \qquad b' = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta

The fusion code:
```python
import torch
import torch.nn as nn
import torchvision as tv

class DummyModule(nn.Module):
    """Placeholder that replaces a fused-away BatchNorm layer."""
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x

def fuse(conv, bn):
    # Fold the BN affine transform into the conv weights and bias.
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight  # BN scale
    beta = bn.bias     # BN shift
    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * gamma + beta
    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv

def fuse_module(m):
    # Walk the module tree; whenever a BatchNorm2d directly follows a
    # Conv2d, replace the conv with the fused version and the BN with
    # a DummyModule.
    children = list(m.named_children())
    c = None
    cn = None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d):
            bc = fuse(c, child)
            m._modules[cn] = bc
            m._modules[name] = DummyModule()
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        else:
            fuse_module(child)

def test_net(m):
    p = torch.randn([1, 3, 224, 224])
    import time
    s = time.time()
    o_output = m(p)
    print("Original time: ", time.time() - s)
    fuse_module(m)
    s = time.time()
    f_output = m(p)
    print("Fused time: ", time.time() - s)
    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    assert o_output.argmax() == f_output.argmax()
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())

def test_layer():
    p = torch.randn([1, 3, 112, 112])
    conv1 = m.conv1
    bn1 = m.bn1
    o_output = bn1(conv1(p))
    fusion = fuse(conv1, bn1)
    f_output = fusion(p)
    print(o_output[0][0][0][0].item())
    print(f_output[0][0][0][0].item())
    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())

m = tv.models.resnet152(pretrained=True)
m.eval()  # fusion relies on the running statistics, so eval mode is required
print("Layer level test: ")
test_layer()
print("============================")
print("Module level test: ")
m = tv.models.resnet18(pretrained=True)
m.eval()
test_net(m)
```
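If downloading the pretrained torchvision weights is inconvenient, the `fuse` function above can also be checked on a hand-built Conv+BN pair; this is a minimal sketch with arbitrary shapes, where the statistics are randomized only so the test is not trivial:
```python
conv = nn.Conv2d(3, 8, 3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)
bn.eval()  # fusion assumes inference mode

# Give the BN layer non-trivial statistics and affine parameters.
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    ref = bn(conv(x))
    out = fuse(conv, bn)(x)
print("Max abs diff:", (ref - out).abs().max().item())  # ~1e-7
```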