批归一化处理(Batch Normalization, BN层)通常用于深层的神经网络中,其作用是对网络中某层特征进行标准化处理,其目的是解决深层神经网络中的数值不稳定的问题,是的同批次的各个特征分不相近,网络更加容易训练。
BN层一般是放在仿射变换,FC或CONV,层后,在非线性激活单元之前(也有一些实现放在线性激活单元后,但是普遍用法是前者)。
深层网络的数值不稳定的问题是,随着网络层数加深,训练过程中参数的更新容易造成靠近输出层的特征输出产生剧烈地变化,不利于训练出有效的神经网络。
BN层原理
BN层常见有针对1d特征(全连接层后面的)以及针对2d特征(2d卷积层后面的)两种(这里暂不讨论3d卷积)。 依照管理,我们从1d开始理解。
BatchNorm1d
假设在训练过程中,一批次的全连接层输出的特征为 x , 张量的shape为 (batch size=m, feature dim=d),我们如下表示输入:
x
=
[
x
(
1
)
,
x
(
2
)
,
.
.
.
,
x
(
m
)
]
x = [x^{(1)}, x^{(2)}, ..., x^{(m)}] \\
x=[x(1),x(2),...,x(m)]
- 首先,我们需要沿着batch的维度 对该小批量的特征求其均值和方差:
μ B = 1 m ∑ i = 1 m x ( i ) σ B 2 = 1 m ∑ i = 1 m ( x ( i ) − μ B ) 2 \mu_{B} = \frac{1}{m}\sum_{i=1}^{m}{x^{(i)}}\\ \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}{(x^{(i)} - \mu_B)^2} μB=m1i=1∑mx(i)σB2=m1i=1∑m(x(i)−μB)2
上式中,需要注意的是,求方差时的平方符号是针对逐元素求平方,而非求范数。所以,特征的每个channel的均值和方差时独立的,均值和方差的shape应该是feature dim = d,或者为便于计算可以认为是(1, d)。
- 接着,我们要对特征进行标注化,并且计算BN层的输出。
x ^ ( i ) = x ( i ) − μ B σ B 2 + ϵ \hat{x}^{(i)} = \frac{x^{(i)} - \mu_{B}}{\sqrt{\sigma_B^2+\epsilon}} x^(i)=σB2+ϵx(i)−μB
上式这一步是对特征进行标准化,式中的epsilon为一个很小的常量,为了避免初零的异常。 然后,就需要计算BN层的输出了。在训练过程中,BN层有两个可以训练的参数,gamma和beta,通过这两个参数与归一化特征计算得到输出:
y
(
i
)
=
γ
⊙
x
^
(
i
)
+
β
y
=
[
y
(
1
)
,
y
(
2
)
,
.
.
.
,
y
(
m
)
]
y^{(i)} = \gamma\odot\hat{x}^{(i)}+\beta \\ \space y = [y^{(1)},y^{(2)},...,y^{(m)}]
y(i)=γ⊙x^(i)+β y=[y(1),y(2),...,y(m)]
于是,BN层的一个forward过程就结束了,其输出特征与输入特征shape相同。并且,一个BN层的可训练参数量与fearure dim有关,为 2*d ,即gamma和beta的参数。
BatchNorm2d
在卷积层后面的BN与1d情况的BN层是很类似的,同样是沿着batch的维度求均值和方差。并且,它的参数大小也仅仅与feature dim有关,也为2*d。 但是,有一点需要注意的是,在求均值和方差的时候,实际上其不仅是沿着batch的维度求取,在每个channel上的宽度和高度方向也求取均值。
假设,BN层上一层的卷积层的输出特征为shape [b, c, h, w] 为 [2, 2, 2, 2],如上图。我们可能会以为,按照1d的情形推理,在求均值时沿着batch的维度进行求取,那么它的均值和方差的shape应该是[1, 2, 2, 2],得到如下图的结果:
但是,实际上,均值和方差的shape是[1, 2, 1, 1]。也就是说,求取均值和方差时,是沿着batch, h, w三个维度进行的,只保证每个channel的统计值是独立的,所以求得均值和方差:
笔者期初也困惑其原因,觉得这样会破坏输入特征的拓扑结构。但是查看了一些框架的源码,确实是按照上文描述的方式。注意到这一点,详细其他地方大家可以按照1d的情形进行推演。
预测时的均值和方差
在模型预测时,我们往往希望得到稳定的输出。所以,无论预测的输入batch size为多少,我们都希望BN层始终用同一个mean和var,而不该根据输入进行变化。在实现的时候,一般采用一阶指数平滑算法实现,其原理如下。
假设一个时间序列的长度为n(在此处,每个序列的时刻为训练的一个最小步骤),我们定义一个时间序列 S = s1, s2, …, sn; 拟合序列T = t1, t2, …, tn. 这里的时间序列是每个训练迭代计算得到的mean和var,拟合序列是每个迭代,通过平滑算法计算的mean和var。那么,对于一阶指数平滑,有下式:
t
i
=
α
s
i
+
(
1
−
α
)
t
i
−
1
t_i = \alpha s_i +(1-\alpha)t_{i-1}
ti=αsi+(1−α)ti−1
在进行算法前,要定义初始值t0,一般可以定义时间序列前三项的均值,也可初始为其他值。利用该算法,在每次迭代时,可以计算一个mean和var值,BN层需要将改值保存,用于预测的时候使用。 这两个值的参数量也分别与feature dim一致,它们是不可训练的参数。
利用torch自定义BN层
了解原理后,我们可以基于pytorch提供的框架,自己定义BN层,熟悉其操作过程,代码如下:
import torch
import torch.nn as nn
"""
barch normalization 操作
"""
def batch_norm(is_training, X, gamma, beta, moving_mean,
moving_var, eps, momentum):
if not is_training:
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
feature_shape = len(X.shape)
assert feature_shape in (2, 4)
if feature_shape == 2:
# 全连接层后面的BN
mean = X.mean(dim=0)
var = ((X - mean) ** 2).mean(dim=0)
else:
# 卷积层后面的BN
mean = X.mean(dim=(0, 2, 3), keepdim=True)
var = ((X-mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
X_hat = (X-mean) / torch.sqrt(var + eps)
print(mean.shape)
print(var.shape)
print(X.shape)
# 一阶指数平滑
moving_mean = momentum * moving_mean + (1. - momentum) * mean
moving_var = momentum * moving_var + (1. -momentum) * var
Y = gamma * X_hat + beta
return Y, moving_mean, moving_var
class _BatchNorm(nn.Module):
"""
BN层
"""
def __init__(self, num_features, num_dims, momentum):
super(_BatchNorm, self).__init__()
assert num_dims in (2, 4)
if num_dims == 2:
shape = (1, num_features)
else:
shape = (1, num_features, 1, 1)
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
self.moving_mean = torch.zeros(shape)
self.moving_var = torch.zeros(shape)
self.momentum = momentum
def forward(self, X):
Y, self.moving_mean, self.moving_var = batch_norm(
self.training, X, self.gamma, self.beta, self.moving_mean,
self.moving_var, eps=1e-5, momentum=self.momentum
)
return Y
class BatchNorm1d(_BatchNorm):
def __init__(self, num_features, momentum=0.9):
super().__init__(num_features, 2, momentum)
class BatchNorm2d(_BatchNorm):
def __init__(self, num_features, momentum=0.9):
super().__init__(num_features, 4, momentum)
测试一下代码: