一. Batch Normalization所解决的问题
Batch Normalization,是针对“Internal Covariate Shift”问题提出的,一种cross-batch的数据标准化方法,已被实践证明能促进深度学习中的BP梯度反向传递流的正常工作,从而已成为诸多深度学习网络架构的常用层。
在传统机器学习中,输入数据各特征维度的标准化(转换为服从标准正态分布)的预处理过程,对于涉及矩阵运算以及梯度求导的机器学习算法大有裨益,可便于矩阵运算,并加速算法迭代收敛。
而在深度学习中,随着网络架构的不断加深,也存在类似的问题。对于原始数据,其被称为“Outer Covariate Shift”,同样可以通过上述的标准化处理保证输入数据的同分布。但在结构内部,由于网络架构中的卷积、投影和非线性激活等操作,导致每层网络接受到的输入值呈现非同分布的趋势(也就是“Internal Covariate Shift”),这给梯度反向传递带来了如下困难:
(1)网络必须学习去适应这种非同分布的数据特点,加重了网络学习的负担;
(2)随着网络的加深,隐层的数据会主键发生偏移,容易落入激活函数的饱和区域,导致梯度消失或梯度爆炸问题,加大了网络学习的难度;
(3)由于(2)的存在,导致整个网络模型在学习率的选取、权重初始化上必须谨慎选择。
而Batch Normalization通过在每个层输入数据中,强制的进行标准化处理,使数据分布满足统一的标准正态分布,经验证能够有效的解决上述问题。
二. Batch Normalization的计算流程
Batch Normalization分训练阶段和使用阶段。
2.1 训练阶段
在训练阶段,其对每个mini-batch中各layer中非线性激活前的各channel数据进行标准化化处理。
假设对于某神经网络,其第 L L L隐藏层为 σ ( W X + b ) \sigma(WX+b) σ(WX+b),其中 X X X为输入数据(即第 L − 1 L-1 L−1层的输出数据), σ \sigma σ为非线性激活函数。Batch Normalization一般作用在 W X + b WX+b WX+b之后,非线性激活之前!
假设在该层,mini-batch训练数据中,某个channel上的所有数据为
x
1
,
x
2
,
.
.
.
x
m
{x_1,x_2,...x_m}
x1,x2,...xm,则其具体算法流程为:
其对应的解释为:
(1)所有数据的均值;
(2)所有数据的方差值;
(3)利用上面计算能得到的均值和方式,对所有数据进行标准化处理,使其落入标准正态分布内;
(4)设置两个可学习的参数
γ
、
β
\gamma、\beta
γ、β,分别对上面的标准结果进行缩放(scale)和偏移(shift),从而得到BN之后的最终结果。
上述计算流程有两个需要注意和解释的点:
(1)参数
γ
、
β
\gamma、\beta
γ、β的意义:对数据进行标准化的做法可能太强,所以通过设置这两个参数来实现标准化的逆函数,使得模型寻找原始数据和标准化处理的一个折中。如果模型通过学习,认为无需进行标准化,则这两个参数保留了再转化为原始输入数据的能力。
(2)各mini-batch的均值和方差应予以保留,以便在使用阶段使用。
2.2 使用阶段
在使用阶段,对数据的推断是逐样本的,而非逐batch的。这与Batch Normalization的cross-batch设计所不符。因此,我们需要从在训练阶段的各mini-batch的均值和方差数据中得到一个具有统计意义的值,用于推断:
然后利用训练好的各channel的
γ
、
β
\gamma、\beta
γ、β参数,对样本数据进行变换:
在实践中,统计手法往往通过移动平均的方式进行统计。
2.3 Batch Normalization的代码示例
下面根据上面的算法流程介绍,以图像的四维数据为例(batch×channel×height×width),给出pytorch的实现方式:
import torch
import torch.nn as nn
class MyBatchNorm(nn.Module):
def __init__(self, moment=0.9, eps = 1e-5, train=True):
super(MyBatchNorm, self).__init__()
self.initial = False
self.gamma = None # 缩放因子
self.beta = None # 移动因子
self.running_mean = None # 移动平均数
self.running_var = None # 移动方差
assert 0 <= moment <= 1
self.moment = moment # 移动稀疏
self.eps = torch.tensor(eps)
def forward(self, inputs):
assert inputs.dim() == 4 # (N, C, H, W)
if self.train:
if not self.initial: # 实现参数初始化
channel_size = inputs.shape[1]
self.gamma = nn.init.normal_(torch.zeros(channel_size, requires_grad=True), 0, 1) # (C)
self.beta = nn.init.normal_(torch.zeros(channel_size, requires_grad=True), 0, 1) # (C)
self.running_mean = torch.zeros(channel_size) # (C)
self.running_var = torch.zeros(channel_size) # (C)
self.initial = True
mean_x = torch.mean(inputs, dim=(0, 2, 3)) # (C)
var_x = torch.var(inputs - mean_x.view(1, -1, 1, 1), dim=(0, 2, 3)) # (C)
self.running_mean = self.running_mean * self.moment + mean_x * (1 - self.moment) # 计算移动平均数
self.running_var = self.running_var * self.moment + var_x * (1 - self.moment) # 计算移动平方数
inputs = (inputs - mean_x.view(1, -1, 1, 1)) / torch.sqrt(var_x.view(1, -1, 1, 1) + self.eps) # BN计算
inputs = self.gamma.view(1, -1, 1, 1) * inputs + self.beta.view(1, -1, 1, 1)
else:
assert self.initial
inputs = self.gamma.view(1, -1, 1, 1)/torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps) * inputs + (self.beta.view(1, -1, 1, 1) -
self.gamma.view(1, -1, 1, 1)*self.running_mean.view(1, -1, 1, 1)/torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps)) # BN推断
return inputs
三. Batch Normalization的优点和缺点
Batch Normalization很好的解决了“Internal Covariate Shift”问题,实践证明其在下面方面有很好的效果:
(1)保证了激活函数的有效性,从而有利于导数反向传播流的传递。通过Batch Normalization的数据被压缩在0-1的标准正态分布内,使得数据满足了同分布的要求,同时有效防止了数据落入后续激活函数的饱和区间内,进而规避了梯度爆炸和消失问题,保证了模型迭代的收敛,提升了模型训练的速度。
(2)对于学习率和权重初始化提供了更宽松的可能性。
(3)提供了类似于Dropout的正则化效果,从而可以防止模型的过拟合。
其缺点也非常明显:
(1)Batch Normalization是一种cross-batch的标准化过程,其训练阶段依托于每个mini-batch内数据的统计计算,而使用阶段仅作用于单个样本上,这使得其在训练和使用阶段并不一致;
(2)Batch Normalization的效果严重依赖于Batch的尺寸,而在实际生产中,取决于内存大小和运算速度的要求,无法保证Batch的尺寸,进而Batch Normalization的效果。
(3)由于文本序列长度的不一致,Batch Normalization并不适用于RNN等网络。
四. 衍生出的各类Normalization
针对Batch Normalization的特点和不足,在cross-batch的标准化外,人们又分别针对数据的不同维度,提出了各式各样的Normalization过程,主要包括:
4.1 Layer Normalization
Layer Normalization是作用于各样本上的,cross-layer(即便cross-channel)的规范化方式,其无需依托于mini-batch上的局部数据,每个训练和推断样本均可利用其本身的数据的进行规范化操作,所以其使用更方便,在Tranformer中就广泛采用了该架构。
以四个维度的图像(batch×channel×height×width)
数据为例,就是在channel×height×width上进行规范化。
4.2 Instance Normalization
Layer Normalization是作用于各样本的各个layer/channel上的规范化方式,其同样作用于各个样本,且粒度更细。
以四个维度的图像(batch×channel×height×width)
数据为例,就是在每个height×width上进行规范化。
4.4 Group Normalization
Layer Normalization的粒度太粗,而Instance Normalization的粒度又太细, Group Normalization做到了平衡,即在一定数目的channel上进行规范化。
以四个维度的图像(batch×channel×height×width)
数据为例,就是在每个N_channel×height×width上进行规范化。
4.5 Switchable Normalization
Switchable Normalization是Batch Normalization、Layer Normalization和Instance Normalization的综合,其每层的规范化操作是上述三个规范化的加权和,加权系数为可学习的参数。
上述各种规范化的区别可通过下图进行直观的显示:
【Reference】