写作目的
BatchNorm工作中一直在用,之前也看了好几遍paper,一直没有理解透彻。最近OpenMMLab发布了BN的pytorch源码解读,跟着其理了一遍BN的实现,有了新的理解,特此记录。
BatchNorm初衷
缓解造成深度神经网络训练难的一个问题—— internal covariate shift(在训练过程中,由于前面层的参数改变导致每层的输入分布在不断改变)。
BatchNorm原理
BatchNorm in Training
上述公式为原Paper中给出的公式,这里有几个比较困惑的问题:
(1)为什么归一化之后又要加gamma和beta,破坏归一化分布?
该问题在原论文(Sec3)中能找到答案。直接归一化会丢失一些学习到的特征,为了保证插入的模块能完成恒等变换,能将归一化后的分布恢复为原始学习到的分布,加入了参数gamma和beta。在极端情况gamma等于方差,beta等于均值的时候BN能恢复原始的输入。
个人理解,这里在保留原始特征和归一化之间做了一个折中,使得BN输出的特征既能尽可能保持之前网络学习到的特征,同时也能尽可能归一化特征分布。
(2) 为什么均值和方差在除了channel之外的其他维度上求。
对应于卷积计算过程,输入的channel是多大,其对应的feature map就有多深,即每个channel对应于一个二维的核(cs231n)。同理每个channel对应一个mean和variance统计值,也对应一个gamma和一个bata。
BatchNorm in Inference
上述公式为原paper中给出的公式,有一个比较困惑的问题:
(1)为什么inference时的公式和training时的公式看上去不一样?
将该式子变换为gamma和beta的公式,如下所示:
现在看上去就和训练时候是一样了。但是由于测试时候均值、方差以及参数gamma和beta都不用计算,所以写成了原paper中所示的形式,省去了除法计算。
BatchNorm计算
参考OpenMMLab对PyTorch中BatchNorm实现的解读,手动计算BatchNorm,过程如下所示。
首先明确一下BatchNorm的输入是什么——在卷积层之后的BatchNorm输入是卷积层的output,即一个shape为NCHW的tensor。假设其输入为以下N=2,C=3,H=2,W=2的随机tensor。
>>> inputs = torch.randn(2, 3, 2, 2)
>>> inputs
tensor([[[[-0.8030, 0.2264],
[-0.8409, -0.9273]],
[[-0.5455, -0.7653],
[ 0.3283, -0.1631]],
[[-0.1568, 1.5355],
[ 1.7826, 0.8605]]],
[[[-1.9135, 1.2265],
[ 1.5482, -0.1545]],
[[ 2.3411, 1.5531],
[ 0.9359, 1.1492]],
[[-0.2346, -0.7456],
[ 0.0338, -0.8034]]]])
即第一个样本的卷积层output为:
[[[-0.8030, 0.2264],
[-0.8409, -0.9273]],
[[-0.5455, -0.7653],
[ 0.3283, -0.1631]],
[[-0.1568, 1.5355],
[ 1.7826, 0.8605]]]
第二个样本的卷积层output为:
[[[-1.9135, 1.2265],
[ 1.5482, -0.1545]],
[[ 2.3411, 1.5531],
[ 0.9359, 1.1492]],
[[-0.2346, -0.7456],
[ 0.0338, -0.8034]]]
(1)在N的维度上求均值,得到两个样本的平均特征,shape为CHW,结果如下所示:
tensor([[[-1.3582, 0.7265],
[ 0.3537, -0.5409]],
[[ 0.8978, 0.3939],
[ 0.6321, 0.4930]],
[[-0.1957, 0.3950],
[ 0.9082, 0.0286]]])
此处即对应位置的两个数相加除以2,如左上角-1.3582 = (-0.8030-1.9135)/2
(2)对于该平均特征,在H的维度上求均值,shape为C*W,结果如下所示:
tensor([[-0.5023, 0.0928],
[ 0.7650, 0.4435],
[ 0.3562, 0.2118]])
此处即对没一列求平均,例如左上角-0.5023 = (-1.3582+0.3537)/2
(3)在W的维度上继续求平均,得到最终的inputs_mean,为一个C维的tensor,结果如下所示:
inputs_mean = tensor([-0.2048, 0.6042, 0.2840])
此处即对每一行求平均,例如第一个数 -0.2048= (-0.5023 + 0.0928)/2
(4)按照上述三步同样的方法求得方差,如下所示:
inputs_var = tensor([1.1893, 1.0233, 0.8641])#unbiased=False,有偏
(5)根据求出的均值和方差对输入进行归一化:
inputs = inputs - inputs_mean[None, ..., None, None]
inputs = inputs / torch.sqrt(inputs_var[None, ..., None, None] + eps)
(6)对归一化后的结果乘以gamma并加上beta:
inputs = inputs * bn_weight[..., None, None] + bn_bias[..., None, None]
(7)更新均值running_mean 和方差running_var,用于推理。
running_mean = running_mean * (1 - momentum) + momentum * inputs_mean
running_var = running_var * (1 - momentum) + momentum * inputs_var * n / (n - 1)
BatchNorm的优点
结合以上分析,进一步理解BatchNorm的优点:
(1)允许使用更大的学习率,使得训练更快。
原因:BatchNorm能阻止大学习率使得参数发生大的变化后,输出落入非线性函数的饱和区域;另外大的学习率相当于放大了参数,这样会增加反传的梯度,造成梯度爆炸,在采用了BatchNorm之后,反向传播将不会受到参数缩放的影响。(sec3.3)
(2)有正则化的效果,一些情况下可以去除drop out。
原因:采用BatchNorm的时候,网络会同时看到一个batch样本的统计量(均值、方差),而不是一个样本。达到了正则化的效果。一个batch样本采样约随机,正则化效果越好。
(3)使用sigmoid等非线性层时的梯度不容易消失。
原因:BatchNomr能将非线性层输入分布更多的拉到0附近(非线性层梯度较大的区间)。
加速BN网络训练的方法
在原论文(sec4.2.1)中,作者提出了一些让BN工作得更好得技巧,如下所示:
(1)增加学习率。
(2)移除Dropout。
(3)减少权重的L2正则化,比如除以5。
(4)加快学习率衰减的step。
(5)更彻底的Shuffle训练样本。
(6)减小 photometric distortions(光度畸变)。
(7)移除 Local Response Normalization。
BN合并到卷积层
在模型部署时,往往需要把BN合并到卷积层,已加快推理。合并BN的公式推导如下: