BatchNorm (PyTorch)

Why use BN?
Without Batch Norm, if the input features differ greatly in scale, the loss surface becomes strongly elongated: gradient descent steps are unbalanced across the two directions, and training cannot converge stably.

There are currently four well-known normalization methods. For input data of shape [N, C, H, W] (N is the number of tensors in the batch, C the channels, H the height, W the width), they differ in which dimensions the mean and variance are computed over (see the sketch after this list):

  • Batch Norm: computes a mean and a variance per channel, across the whole batch (all N tensors). For an input of shape [10, 4, 9] (N=10, C=4, H*W=9), the statistics form one value per channel, i.e. a 1×4 tensor such as [0, 1, 2, 3].
  • Layer Norm: computes the mean and variance over all channels of each individual tensor.
  • Instance Norm: computes the mean and variance separately for each channel of each tensor.
  • Group Norm: introduces channel groups, e.g. the B, G, R channels forming one group; less common.
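A minimal sketch contrasting the four layers on the same input; the shapes and the num_groups value below are illustrative assumptions, not from the original post:

```python
import torch
import torch.nn as nn

x = torch.randn(10, 4, 8, 8)  # [N, C, H, W]

bn = nn.BatchNorm2d(4)                            # per channel, over N*H*W
ln = nn.LayerNorm([4, 8, 8])                      # per sample, over C*H*W
inorm = nn.InstanceNorm2d(4)                      # per sample and channel, over H*W
gn = nn.GroupNorm(num_groups=2, num_channels=4)   # per sample, per channel group

for name, layer in [("BatchNorm2d", bn), ("LayerNorm", ln),
                    ("InstanceNorm2d", inorm), ("GroupNorm", gn)]:
    print(name, layer(x).shape)  # every variant preserves the input shape
```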

BatchNorm2d

The BN computation:

For each channel, BatchNorm2d normalizes over the (N, H, W) dimensions of the mini-batch:

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$

where $\gamma$ and $\beta$ are learnable per-channel parameters (present when affine=True).

```
Shape:
    - Input: (N, C, H, W)
    - Output: (N, C, H, W) (same shape as input)

Examples::
    >>> # With Learnable Parameters
    >>> m = nn.BatchNorm2d(100)
    >>> # Without Learnable Parameters
    >>> m = nn.BatchNorm2d(100, affine=False)
    >>> input = torch.randn(20, 100, 35, 45)
    >>> output = m(input)
```
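Note that BatchNorm behaves differently in training and evaluation: in train mode it normalizes with the current batch statistics and updates running_mean/running_var; in eval mode it normalizes with those running statistics. A minimal sketch (momentum=0.1 is the PyTorch default):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 4, 4)

bn.train()
_ = bn(x)               # normalizes with batch stats, updates running stats
print(bn.running_mean)  # roughly 0.1 * per-channel batch mean after one step

bn.eval()
y = bn(x)               # now normalizes with running_mean / running_var
```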

A concrete implementation example:

```python
import torch
import torch.nn as nn

m = nn.BatchNorm2d(2, affine=True)  # learnable weight (gamma) and bias (beta) are used
input = torch.randn(1, 2, 3, 4)
output = m(input)

print("输入图片:")
print(input)
print("归一化权重:")
print(m.weight)
print("归一化的偏重:")
print(m.bias)

print("归一化的输出:")
print(output)
print("输出的尺度:")
print(output.size())

print("First channel of the input:")
print(input[0][0])
first_channel_mean = input[0][0].mean()
first_channel_var = input[0][0].var(unbiased=False)  # biased variance: Bessel's correction is NOT applied, matching BN

print(m.eps)
print("Mean of the first channel:")
print(first_channel_mean)
print("Variance of the first channel:")
print(first_channel_var)

batch_norm_first = (input[0][0][0][0] - first_channel_mean) \
    / torch.sqrt(first_channel_var + m.eps) \
    * m.weight[0] + m.bias[0]
print(batch_norm_first)  # should match output[0][0][0][0]
```


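The single-element check generalizes to the whole tensor: because the batch size is 1, the per-channel statistics over input[0][c] are exactly the batch statistics BatchNorm uses. A minimal sketch continuing the variables from the script above:

```python
# Reproduce the full BatchNorm2d output by hand (train mode, batch size 1).
with torch.no_grad():
    manual = torch.empty_like(input)
    for c in range(input.size(1)):
        mean = input[0][c].mean()
        var = input[0][c].var(unbiased=False)  # biased variance, as BN uses
        manual[0][c] = (input[0][c] - mean) / torch.sqrt(var + m.eps) \
                       * m.weight[c] + m.bias[c]
print(torch.allclose(manual, output, atol=1e-6))  # expected: True
```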

3. Fusing Convolution and BatchNorm in PyTorch

In eval mode, BatchNorm applies a fixed per-channel affine transform $y = \gamma \cdot (x - \mu)/\sqrt{\sigma^2 + \epsilon} + \beta$ using the running statistics $\mu$ and $\sigma^2$. Since the preceding convolution is also affine, $x = W * z + b$, the two can be folded into a single convolution with

$$W' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W, \qquad b' = \frac{\gamma \cdot (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Code after fusion:

```python
import torch
import torch.nn as nn
import torchvision as tv


class DummyModule(nn.Module):
    """Identity placeholder that stands in for a fused-away BatchNorm layer."""
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x


def fuse(conv, bn):
    # Fold an eval-mode BatchNorm into the preceding convolution:
    #   y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    gamma = bn.weight  # BN scale
    beta = bn.bias     # BN shift

    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)

    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * gamma + beta
    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           dilation=conv.dilation,
                           groups=conv.groups,
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def fuse_module(m):
    # Walk the module tree; assumes each BatchNorm2d directly follows
    # the Conv2d it normalizes (true for torchvision's ResNets).
    children = list(m.named_children())
    c = None
    cn = None

    for name, child in children:
        if isinstance(child, nn.BatchNorm2d):
            bc = fuse(c, child)
            m._modules[cn] = bc
            m._modules[name] = DummyModule()  # the BN slot becomes a no-op
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        else:
            fuse_module(child)


def test_net(m):
    # Compare the full network before and after in-place fusion.
    p = torch.randn([1, 3, 224, 224])
    import time
    s = time.time()
    o_output = m(p)
    print("Original time: ", time.time() - s)

    fuse_module(m)  # fuses m in place

    s = time.time()
    f_output = m(p)
    print("Fused time: ", time.time() - s)

    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    assert o_output.argmax() == f_output.argmax()
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())


def test_layer():
    # Fuse a single conv+BN pair (uses the global model m defined below).
    p = torch.randn([1, 3, 112, 112])
    conv1 = m.conv1
    bn1 = m.bn1
    o_output = bn1(conv1(p))
    fusion = fuse(conv1, bn1)
    f_output = fusion(p)
    print(o_output[0][0][0][0].item())
    print(f_output[0][0][0][0].item())
    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())


m = tv.models.resnet152(True)  # pretrained weights
m.eval()                       # fusion is only valid in eval mode (running statistics)
print("Layer level test: ")
test_layer()

print("============================")
print("Module level test: ")
m = tv.models.resnet18(True)
m.eval()
test_net(m)
```
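As a cross-check, recent PyTorch versions ship a built-in eval-mode conv+BN fusion helper; the import path below is what current releases use but may vary across versions, so treat this as an assumption:

```python
from torch.nn.utils.fusion import fuse_conv_bn_eval
import torch
import torchvision as tv

net = tv.models.resnet18(True)
net.eval()

builtin = fuse_conv_bn_eval(net.conv1, net.bn1)  # PyTorch's own fusion
ours = fuse(net.conv1, net.bn1)                  # the fuse() defined above

p = torch.randn(1, 3, 224, 224)
print((builtin(p) - ours(p)).abs().max().item())  # expected to be ~0
```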
