1. 协方差偏移
1.1 协方差偏移的概念
简而言之, 深度网络内部数据分布在训练过程中发生变化的现象。
统计机器学习的一个经典假设是“源域和目标域的数据分布是一致的”。协方差偏移就是分布不一致假设下的一个分支问题,他指的的源域和目标域的条件概率分布一致,但是边缘概率分布不一样,即
P
s
(
Y
∣
X
=
x
)
=
P
t
(
Y
∣
X
=
x
)
P_s(Y|X=x)=P_t(Y|X=x)
Ps(Y∣X=x)=Pt(Y∣X=x)
P
s
(
X
)
≠
P
t
(
X
)
P_s(X)\neq P_t(X)
Ps(X)=Pt(X)对于神经网络的各层输出来说,由于它们经过了层内操作的作用,分布显然与各层的输入分布是不同的,而且差异会随着网络深度的增大而增大,但是他们的标签确实仍然不变的,因此符合了协方差偏移的定义。
1.2 协方差偏移的影响
训练深度网络的时候,神经网络隐藏层参数更新会导致网络输出层输出数据的分布发生变化,如果不做归一化,那么每层网络输入的数据分布都是不同的。随着层数的增加,根据链式法则,这种偏移现象会逐渐被放大,高层的输入分布变化的会非常剧烈,使得高层需要不断的去适应底层的参数更新。为了保证网络参数的稳定性和收敛性,往往会选择比较小的学习率,同时参数初始化的好坏也明显影响训练出的模型精度,特别是在训练具有饱和非线性的网络,比如采用S或者双S激活函数的网络,例如LSTM、GRU。
简而言之,协方差偏移导致了神经元的输入数据不再是“独立同分布”,有三个缺点
- 上层参数需要不断的适应新的输入数据分布,降低了学习速度
- 下层输入的变化可能趋向于变大或者变小,导致上层落入饱和区(梯度消失和爆炸),使得学习过早停止
- 每次的更新都会影响到其它层,因此每层的参数更新策略需要尽可能的谨慎
2. 归一化的通用框架与基本思想
假设神经网络的输入是 X = ( x 1 , x 2 , ⋯ , x d ) X = (x_1,x_2,\cdots,x_d) X=(x1,x2,⋯,xd),通过运算,得到输出 y = f ( X ) y =f(X) y=f(X)。
由于协方差偏移的存在,X的分布可能相差很大,要解决独立同分布问题,最好的方法就是对于每一层的数据都进行白化操作,但是标准的白化操作代价高昂,由于反向传播,我们还希望白化操作是可微的以保证可以进行梯度的更新。
白化操作是一个重要的数据预处理步骤,一般包含两个目的
- 去除特征之间的相关性 ⇒ \Rightarrow ⇒ 独立
- 是的所有特征具有相同的均值和方差 ⇒ \Rightarrow ⇒ 同分布
因此,以BN为代表的归一化方法退而求其次,进行了简化的操作,基本思想是:在 x 输入网络之前,先对其进行 平移和伸缩 变换,将 x 的分布规范化在固定区间范围的标准分布。通用的变化框架如下
h
=
f
(
g
⋅
x
−
μ
σ
+
b
)
h =f(g\cdot \frac{x-\mu}{\sigma}+b)
h=f(g⋅σx−μ+b)(1)在上式中,
μ
\mu
μ 是平移参数(shift),
σ
\sigma
σ 是缩放参数(scale),通过这两个参数进行平移和缩放变换。
x
^
=
x
−
μ
σ
\hat{x}=\frac{x-\mu}{\sigma}
x^=σx−μ得到的数据符合均值为0,方差为1的标准分布。
(2)在上式中,
b
b
b 是再平移参数(re-shift),
g
g
g 是再缩放参数(re-scale)。将上一步得到的
x
^
\hat{x}
x^ 进一步变换为
y
=
g
⋅
x
^
+
b
y=g\cdot \hat{x} +b
y=g⋅x^+b最终得到的数据符合均值为
b
b
b,方差为
g
2
g^2
g2 的分布。
其中,再平移参数和再缩放参数是可学习的,这使得Normalization层能够尊重底层的学习结果,从而保证模型的表达能力不会因为归一化而下降。
标准的白化操作是独立同分布,通过Normalization后,变换为均值为 b b b,方差为 g 2 g^2 g2 的分布,这并不是严格的同分布,知识映射到了一个确定的区间范围
3. 常见的归一化方法
在上一节中,我们提出了Nomalization的通用公式:
h
=
f
(
g
⋅
x
−
μ
σ
+
b
)
h =f(g\cdot \frac{x-\mu}{\sigma}+b)
h=f(g⋅σx−μ+b)
首先,根据一个示意图来形象的表明BN、LN、IN和GN的区别,在图片中,HW是被合成的一个维度,这个是方便画出示意图,C和N各占一个维度。
3.1 Batch Nomalization
3.1.1 什么是 BN
BN的计算就是把每个通道的NHW单独拿出来归一化处理,其计算算法流程如下所示:
3.1.2 为什么使用 BN
为了加快网络收敛速度,解决网络的梯度消失与梯度爆炸问题
以sigmoid函数为例子,其函数图像如下所示
当x到了一定的大小,经过sigmoid函数的输出范围就很小了,如下所示
如果输入很大或者很小,那么梯度就会很小,梯度在反向传播中是网络权重学习的速率,就会出现下面问题
在深度网络中,如果网络的激活输出很大,其梯度就很小,学习速率就很慢。假设每层学习梯度都小于最大值0.25,网络有n层,因为链式求导的原因,第一层的梯度小于0.25的n次方,所以学习速率就慢,对于最后一层只需对自身求导1次,梯度就大,学习速率就快。
这会造成的影响是在一个很大的深度网络中,浅层基本不学习,权值变化小,后面几层一直在学习,结果就是,后面几层基本可以表示整个网络,失去了深度的意义。梯度爆炸也是同理。
3.1.3 怎么使用BN
假设一个batch有N个样本,每个样本通道数为C,高为H,宽为W。在求均值和方差时,将在N、H、W上进行操作,从而保留C的维度。具体来说,就是把第1个样本的第1个通道,加上第2个样本的第1个通道,……,加上第N个样本的第1个通道,求平均,得到通道1的均值,对所有的通道都进行操作,得到所有通道的均值和方差。
公式为
μ
n
(
x
)
=
1
N
H
W
∑
n
=
1
N
∑
h
=
1
H
∑
w
=
1
W
x
n
c
h
w
\mu_n(x)=\frac{1}{NHW}\sum_{n=1}^N\sum_{h=1}^H\sum_{w=1}^Wx_{nchw}
μn(x)=NHW1n=1∑Nh=1∑Hw=1∑Wxnchw
σ
n
(
x
)
=
1
N
H
W
∑
n
=
1
N
∑
h
=
1
H
∑
w
=
1
W
(
x
n
c
h
w
−
μ
c
(
x
)
)
2
+
ϵ
\sigma_n(x)=\sqrt{\frac{1}{NHW}\sum_{n=1}^N\sum_{h=1}^H\sum_{w=1}^W(x_{nchw}-\mu_c(x))^2+\epsilon}
σn(x)=NHW1n=1∑Nh=1∑Hw=1∑W(xnchw−μc(x))2+ϵ
上图中每一列表示一个样本,横向表示通道。
在训练的时候,均值和方差为每一个batch的均值和方差
在测试的时候,使用的均值和方差是全部训练数据的均值和方差,这个可以通过滑动平均的方式得到
为什么训练的时候不使用全部训练数据的均值和方差呢?这是因为使用全量的训练集和方差容易过拟合,对于BN来讲,其实就是对每一批数据进行归一化到一个相同的分布,而每一批数据的均值和方差会有一定的差别,而不是用固定的值,这个差别实际上能够增加模型的鲁棒性,也会在一定程度上减少过拟合。也正是因此,BN一般要求将训练集完全打乱,并用一个较大的batch值,否则,一个batch的数据无法较好得代表训练集的分布,会影响模型训练的效果。
x = torch.randn(10,3,5,5)*1000
x1 = x.transpose(0,1).contiguous().view(3,-1)
mu = x1.mean(dim=1).view(1,3,1,1)
std = x1.std(dim=1),view(1,3,1,1)
bn = (x-mu)/std
# bn = g*bn+b
3.2 Layer Nomalization
在章节3.1中我们提到BN不适用于batch较小的请款下,并且其对于RNN等动态的网络的效果也并不好,而Layer Nomalization很好地解决了这两个问题。
3.2.1 什么是LN
3.2.1.1 MLP中的LN
BN的两个缺点的产生原因均是因为计算归一化统计量时计算的样本数太少。LN是一个独立于batch size的算法,所以无论样本数多少都不会影响参与LN计算的数据量,它综合考虑一层所有维度的输入,计算该层的平均输入值和输入方差,然后用同一个规范化操作来转换各个维度的输入。先看MLP中的LN。设
H
H
H 是一层中隐层节点的数量,
l
l
l 是MLP的层数,我们可以计算LN的归一化统计量
μ
\mu
μ 和
σ
\sigma
σ :
μ
l
=
1
H
∑
i
=
1
H
a
i
l
\mu^l=\frac{1}{H}\sum_{i=1}^{H}a_i^l
μl=H1i=1∑Hail
σ
l
=
1
H
∑
i
=
1
H
(
a
i
l
−
μ
l
)
2
\sigma^l=\sqrt{\frac{1}{H}\sum_{i=1}^{H}(a_i^l-\mu^l)^2}
σl=H1i=1∑H(ail−μl)2注意上面统计量的计算是和样本数量没有关系的,它的数量只取决于隐层节点的数量,所以只要隐层节点的数量足够多,我们就能保证LN的归一化统计量足够具有代表性
a
^
l
=
a
l
−
μ
l
(
σ
l
)
2
+
ϵ
\hat{a}^l=\frac{a^l-\mu^l}{\sqrt{(\sigma^l)^2}+\epsilon}
a^l=(σl)2+ϵal−μl其中
ϵ
\epsilon
ϵ 是一个很小的小数,防止除0(论文中忽略了这个参数)。
同样进行再缩放和再平移
h
l
=
f
(
g
l
⋅
a
^
l
+
b
l
)
h^l = f(g^l \cdot \hat{a}^l +b^l)
hl=f(gl⋅a^l+bl)
3.2.1.2 RNN中的LN
在RNN中,我们可以非常简单的在每个时间片中使用LN,而且在任何时间片我们都能保证归一化统计量统计的是
H
H
H 个节点的信息。对于RNN时刻
t
t
t 时的节点,其输入是
t
−
1
t-1
t−1 时刻的隐层状态
h
t
−
1
h^{t-1}
ht−1 和
t
t
t 时刻的输入数据
x
t
x_t
xt ,可以表示为:
a
t
=
W
h
h
h
t
−
1
+
W
x
h
x
t
a^t = W_{hh}h^{t-1}+W_{xh}x^t
at=Whhht−1+Wxhxt 接着便可以在
a
t
a^t
at 上进行与上节中一样的归一化操作
μ
l
=
1
H
∑
i
=
1
H
a
i
l
\mu^l=\frac{1}{H}\sum_{i=1}^{H}a_i^l
μl=H1i=1∑Hail
σ
l
=
1
H
∑
i
=
1
H
(
a
i
l
−
μ
l
)
2
\sigma^l=\sqrt{\frac{1}{H}\sum_{i=1}^{H}(a_i^l-\mu^l)^2}
σl=H1i=1∑H(ail−μl)2
a
^
l
=
a
l
−
μ
l
(
σ
l
)
2
+
ϵ
\hat{a}^l=\frac{a^l-\mu^l}{\sqrt{(\sigma^l)^2}+\epsilon}
a^l=(σl)2+ϵal−μl
h
l
=
f
(
g
l
⋅
a
^
l
+
b
l
)
h^l = f(g^l \cdot \hat{a}^l +b^l)
hl=f(gl⋅a^l+bl)
3.2.2 为什么使用LN
3.2.2.1 batch size的约束
BN是按照样本数计算归一化统计量的,当样本数很少时,比如说只有4个。这四个样本的均值和方差便不能反映全局的统计分布息,所以基于少量样本的BN的效果会变得很差。
另外,在一些场景中,比如说硬件资源受限,在线学习等场景,BN是非常不适用的。
3.2.2.2 BN与RNN
RNN可以展开成一个隐藏层共享参数的MLP,随着时间片的增多,展开后的MLP的层数也在增多,最终层数由输入数据的时间片的数量决定,所以RNN是一个动态的网络。
在一个batch中,通常各个样本的长度都是不同的,当统计到比较靠后的时间片时,例如图中 t > 4 t>4 t>4 时,这时只有一个样本还有数据,基于这个样本的统计信息不能反映全局分布,所以这时BN的效果并不好。
另外如果在测试时我们遇到了长度大于任何一个训练样本的测试样本,我们无法找到保存的归一化统计量,所以BN无法运行。
3.2.3 怎么使用LN
Batch Normalization 的一个缺点是需要较大的 batchsize 才能合理估训练数据的均值和方差(横向计算),这导致内存很可能不够用,同时它也很难应用在训练数据长度不同的 RNN 模型上。Layer Normalization (LN) 的一个优势是不需要批训练,在单条数据内部就能归一化。
对于 x ∈ R N × C × H × W x \in R^{N\times C\times H\times W} x∈RN×C×H×W, LN 对每个样本的 C、H、W 维度上的数据求均值和标准差,保留 N 维度。
其均值和标准差公式为
μ
n
(
x
)
=
1
C
H
W
∑
c
=
1
C
∑
h
=
1
H
∑
w
=
1
W
x
n
c
h
w
\mu_n(x)=\frac{1}{CHW}\sum_{c=1}^C\sum_{h=1}^H\sum_{w=1}^Wx_{nchw}
μn(x)=CHW1c=1∑Ch=1∑Hw=1∑Wxnchw
σ
n
(
x
)
=
1
C
H
W
∑
c
=
1
C
∑
h
=
1
H
∑
w
=
1
W
(
x
n
c
h
w
−
μ
n
(
x
)
)
2
+
ϵ
\sigma_n(x)=\sqrt{\frac{1}{CHW}\sum_{c=1}^C\sum_{h=1}^H\sum_{w=1}^W(x_{nchw}-\mu_n(x))^2+\epsilon}
σn(x)=CHW1c=1∑Ch=1∑Hw=1∑W(xnchw−μn(x))2+ϵ
x = torch.randn(10,3,5,5)*1000
x1 = x.contiguous().view(10,-1)
mu = x1.mean(dim=1).view(10,1,1,1)
std = x1.std(dim=1),view(10,1,1,1)
bn = (x-mu)/std
# bn = g*bn+b
3.3 Instance Normalization
Instance Normalization (IN) 最初用于图像的风格迁移。作者发现,在生成模型中, feature map 的各个 channel 的均值和方差会影响到最终生成图像的风格,因此可以先把图像在 channel 层面归一化,然后再用目标风格图片对应 channel 的均值和标准差“去归一化”,以期获得目标图片的风格。IN 操作也在单个样本内部进行,不依赖 batch。
对于
x
∈
R
N
×
C
×
H
×
W
x \in R^{N\times C\times H\times W}
x∈RN×C×H×W, LN 对每个样本的 H、W 维度上的数据求均值和标准差,保留 N、C 维度,也就是说,它只在channel内部求平均值和标准差,公式为:
μ
n
(
x
)
=
1
H
W
∑
h
=
1
H
∑
w
=
1
W
x
n
c
h
w
\mu_n(x)=\frac{1}{HW}\sum_{h=1}^H\sum_{w=1}^Wx_{nchw}
μn(x)=HW1h=1∑Hw=1∑Wxnchw
σ
n
(
x
)
=
1
H
W
∑
h
=
1
H
∑
w
=
1
W
(
x
n
c
h
w
−
μ
n
c
(
x
)
)
2
+
ϵ
\sigma_n(x)=\sqrt{\frac{1}{HW}\sum_{h=1}^H\sum_{w=1}^W(x_{nchw}-\mu_{nc}(x))^2+\epsilon}
σn(x)=HW1h=1∑Hw=1∑W(xnchw−μnc(x))2+ϵ
x = torch.randn(10,3,5,5)*1000
x1 = x.view(30,-1)
mu = x1.mean(dim=1).view(10,3,1,1)
std = x1.std(dim=1),view(10,3,1,1)
bn = (x-mu)/std
# bn = g*bn+b
3.4 Group Normalization
Group Normalization (GN) 适用于占用显存比较大的任务,例如图像分割。对这类任务,可能 batchsize 只能是个位数,再大显存就不够用了。而当 batchsize 是个位数时,BN 的表现很差,因为没办法通过几个样本的数据量,来近似总体的均值和标准差。GN 也是独立于 batch 的,它是 LN 和 IN 的折中。
GN 计算均值和标准差时,把每一个样本 feature map 的 channel 分成 G 组,每组将有 C/G 个 channel,然后将这些 channel 中的元素求均值和标准差。各组 channel 用其对应的归一化参数独立地归一化。
μ n ( x ) = 1 ( C / G ) H W ∑ c = g C / G ( g + 1 ) C / G ∑ h = 1 H ∑ w = 1 W x n c h w \mu_n(x)=\frac{1}{(C/G)HW}\sum_{c=gC/G}^{(g+1)C/G}\sum_{h=1}^H\sum_{w=1}^Wx_{nchw} μn(x)=(C/G)HW1c=gC/G∑(g+1)C/Gh=1∑Hw=1∑Wxnchw σ n ( x ) = 1 ( C / G ) H W ∑ c = g C / G ( g + 1 ) C / G ∑ h = 1 H ∑ w = 1 W ( x n c h w − μ n g ( x ) ) 2 + ϵ \sigma_n(x)=\sqrt{\frac{1}{(C/G)HW}\sum_{c=gC/G}^{(g+1)C/G}\sum_{h=1}^H\sum_{w=1}^W(x_{nchw}-\mu_{ng}(x))^2+\epsilon} σn(x)=(C/G)HW1c=gC/G∑(g+1)C/Gh=1∑Hw=1∑W(xnchw−μng(x))2+ϵ
x = torch.randn(10,20,5,5)*1000
# 分为4个组
x1 = x.view(10,4,-1)
mu = x1.mean(dim=-1).view(10,4,-1)
std = x1.std(dim=-1),view(10,4,-1)
x1_norm = (x-mu)/std
bn = x1_norm.reshape(10,20,5,5)
# bn = g*bn+b
4 参考文献
文献1:Juliuszh,详解深度学习中的Normalization,BN/LN/WN,https://zhuanlan.zhihu.com/p/33173246
文献2:时光碎了天, 深度学习中的五种归一化(BN、LN、IN、GN和SN)方法简介,https://blog.csdn.net/u013289254/article/details/99690730
文献3:Dong,BN、LN、IN、GN的简介,https://zhuanlan.zhihu.com/p/91965772
文献4:大师兄,模型优化之Layer Normalization,https://zhuanlan.zhihu.com/p/54530247