Batch Normalization及各类衍生的Normalization

一. Batch Normalization所解决的问题

Batch Normalization,是针对“Internal Covariate Shift”问题提出的,一种cross-batch的数据标准化方法,已被实践证明能促进深度学习中的BP梯度反向传递流的正常工作,从而已成为诸多深度学习网络架构的常用层。

在传统机器学习中,输入数据各特征维度的标准化(转换为服从标准正态分布)的预处理过程,对于涉及矩阵运算以及梯度求导的机器学习算法大有裨益,可便于矩阵运算,并加速算法迭代收敛。

而在深度学习中,随着网络架构的不断加深,也存在类似的问题。对于原始数据,其被称为“Outer Covariate Shift”,同样可以通过上述的标准化处理保证输入数据的同分布。但在结构内部,由于网络架构中的卷积、投影和非线性激活等操作,导致每层网络接受到的输入值呈现非同分布的趋势(也就是“Internal Covariate Shift”),这给梯度反向传递带来了如下困难:
(1)网络必须学习去适应这种非同分布的数据特点,加重了网络学习的负担;
(2)随着网络的加深,隐层的数据会主键发生偏移,容易落入激活函数的饱和区域,导致梯度消失或梯度爆炸问题,加大了网络学习的难度;
(3)由于(2)的存在,导致整个网络模型在学习率的选取、权重初始化上必须谨慎选择。

而Batch Normalization通过在每个层输入数据中,强制的进行标准化处理,使数据分布满足统一的标准正态分布,经验证能够有效的解决上述问题。

二. Batch Normalization的计算流程

Batch Normalization分训练阶段和使用阶段。

2.1 训练阶段

在训练阶段,其对每个mini-batch中各layer中非线性激活前的各channel数据进行标准化化处理。

假设对于某神经网络,其第 L L L隐藏层为 σ ( W X + b ) \sigma(WX+b) σ(WX+b),其中 X X X为输入数据(即第 L − 1 L-1 L1层的输出数据), σ \sigma σ为非线性激活函数。Batch Normalization一般作用在 W X + b WX+b WX+b之后,非线性激活之前!

假设在该层,mini-batch训练数据中,某个channel上的所有数据为 x 1 , x 2 , . . . x m {x_1,x_2,...x_m} x1,x2,...xm,则其具体算法流程为:
在这里插入图片描述
其对应的解释为:
(1)所有数据的均值;
(2)所有数据的方差值;
(3)利用上面计算能得到的均值和方式,对所有数据进行标准化处理,使其落入标准正态分布内;
(4)设置两个可学习的参数 γ 、 β \gamma、\beta γβ,分别对上面的标准结果进行缩放(scale)和偏移(shift),从而得到BN之后的最终结果。

上述计算流程有两个需要注意和解释的点:
(1)参数 γ 、 β \gamma、\beta γβ的意义:对数据进行标准化的做法可能太强,所以通过设置这两个参数来实现标准化的逆函数,使得模型寻找原始数据和标准化处理的一个折中。如果模型通过学习,认为无需进行标准化,则这两个参数保留了再转化为原始输入数据的能力。
(2)各mini-batch的均值和方差应予以保留,以便在使用阶段使用。

2.2 使用阶段

在使用阶段,对数据的推断是逐样本的,而非逐batch的。这与Batch Normalization的cross-batch设计所不符。因此,我们需要从在训练阶段的各mini-batch的均值和方差数据中得到一个具有统计意义的值,用于推断:
在这里插入图片描述
然后利用训练好的各channel的 γ 、 β \gamma、\beta γβ参数,对样本数据进行变换:
在这里插入图片描述
在实践中,统计手法往往通过移动平均的方式进行统计。

2.3 Batch Normalization的代码示例

下面根据上面的算法流程介绍,以图像的四维数据为例(batch×channel×height×width),给出pytorch的实现方式:

import torch
import torch.nn as nn

class MyBatchNorm(nn.Module):
    def __init__(self, moment=0.9, eps = 1e-5, train=True):
        super(MyBatchNorm, self).__init__()
        self.initial = False
        self.gamma = None          # 缩放因子
        self.beta = None          # 移动因子
        self.running_mean = None    # 移动平均数
        self.running_var = None          # 移动方差
        assert 0 <= moment <= 1
        self.moment = moment     # 移动稀疏
        self.eps = torch.tensor(eps)    

    def forward(self, inputs):
        assert inputs.dim() == 4    # (N, C, H, W)
        if self.train:
            if not self.initial:     # 实现参数初始化
                channel_size = inputs.shape[1]
                self.gamma = nn.init.normal_(torch.zeros(channel_size, requires_grad=True), 0, 1)    # (C)
                self.beta = nn.init.normal_(torch.zeros(channel_size, requires_grad=True), 0, 1)     # (C)
                self.running_mean = torch.zeros(channel_size)     # (C)
                self.running_var = torch.zeros(channel_size)      # (C)
                self.initial = True

            mean_x = torch.mean(inputs, dim=(0, 2, 3))   # (C)
            var_x = torch.var(inputs - mean_x.view(1, -1, 1, 1), dim=(0, 2, 3))   # (C)
            self.running_mean = self.running_mean * self.moment + mean_x * (1 - self.moment)   # 计算移动平均数
            self.running_var = self.running_var * self.moment + var_x * (1 - self.moment)    # 计算移动平方数

            inputs = (inputs - mean_x.view(1, -1, 1, 1)) / torch.sqrt(var_x.view(1, -1, 1, 1) + self.eps)   # BN计算
            inputs = self.gamma.view(1, -1, 1, 1) * inputs + self.beta.view(1, -1, 1, 1)

        else:
            assert self.initial
            inputs = self.gamma.view(1, -1, 1, 1)/torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps) * inputs + (self.beta.view(1, -1, 1, 1) - 
                     self.gamma.view(1, -1, 1, 1)*self.running_mean.view(1, -1, 1, 1)/torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps))  # BN推断
        return inputs
三. Batch Normalization的优点和缺点

Batch Normalization很好的解决了“Internal Covariate Shift”问题,实践证明其在下面方面有很好的效果:

(1)保证了激活函数的有效性,从而有利于导数反向传播流的传递。通过Batch Normalization的数据被压缩在0-1的标准正态分布内,使得数据满足了同分布的要求,同时有效防止了数据落入后续激活函数的饱和区间内,进而规避了梯度爆炸和消失问题,保证了模型迭代的收敛,提升了模型训练的速度。

(2)对于学习率和权重初始化提供了更宽松的可能性。

(3)提供了类似于Dropout的正则化效果,从而可以防止模型的过拟合。

其缺点也非常明显:

(1)Batch Normalization是一种cross-batch的标准化过程,其训练阶段依托于每个mini-batch内数据的统计计算,而使用阶段仅作用于单个样本上,这使得其在训练和使用阶段并不一致;

(2)Batch Normalization的效果严重依赖于Batch的尺寸,而在实际生产中,取决于内存大小和运算速度的要求,无法保证Batch的尺寸,进而Batch Normalization的效果。

(3)由于文本序列长度的不一致,Batch Normalization并不适用于RNN等网络。

四. 衍生出的各类Normalization

针对Batch Normalization的特点和不足,在cross-batch的标准化外,人们又分别针对数据的不同维度,提出了各式各样的Normalization过程,主要包括:

4.1 Layer Normalization

Layer Normalization是作用于各样本上的,cross-layer(即便cross-channel)的规范化方式,其无需依托于mini-batch上的局部数据,每个训练和推断样本均可利用其本身的数据的进行规范化操作,所以其使用更方便,在Tranformer中就广泛采用了该架构。

以四个维度的图像(batch×channel×height×width)
数据为例,就是在channel×height×width上进行规范化。

4.2 Instance Normalization

Layer Normalization是作用于各样本的各个layer/channel上的规范化方式,其同样作用于各个样本,且粒度更细。

以四个维度的图像(batch×channel×height×width)
数据为例,就是在每个height×width上进行规范化。

4.4 Group Normalization

Layer Normalization的粒度太粗,而Instance Normalization的粒度又太细, Group Normalization做到了平衡,即在一定数目的channel上进行规范化。

以四个维度的图像(batch×channel×height×width)
数据为例,就是在每个N_channel×height×width上进行规范化。

4.5 Switchable Normalization

Switchable Normalization是Batch Normalization、Layer Normalization和Instance Normalization的综合,其每层的规范化操作是上述三个规范化的加权和,加权系数为可学习的参数。

上述各种规范化的区别可通过下图进行直观的显示:
在这里插入图片描述

【Reference】

  1. Batch Normalization论文
  2. Layer Normalizaiton论文
  3. Instance Normalization论文
  4. Group Normalization论文
  5. Switchable Normalization论文
  6. 各种Normalization的介绍及其numpy实现
  7. 深入理解Batch Normalization
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值