主要贡献:提出归一化方法 RMSNorm,并通过实验证明归一化方法 re-scale 具有重要作用,re-center 收效甚微,基于此,放弃 re-center 步骤,仅保留 re-scale,降低了计算复杂度,提升了训练效率
论文:https://arxiv.org/abs/1910.07467
研究背景
- 提升模型的训练效率是一个长期存在的挑战,深度学习的多年发展中,优化器的改进、损失函数的优化、数据预处理步骤、模型参数初始化、训练权重归一化等,众多的优化方法目的只有一个——提升训练速度,加速模型收敛;
- BatchNorm 在序列问题上存在局限性,序列长度的不同导致 BatchNorm 计算得到的均值和方差不具有代表性,难以胜任序列任务的归一化方式;
- LayerNorm 能够较好的处理序列任务的归一化,作者对 LayerNorm 的两个核心机制 re-center、re-scale 进行研究,假设 re-scale 是 LayerNorm 最主要的机制,并通过实验验证猜想提出 RMSNorm,具有更少的计算量,更快的训练效率,可匹敌的收敛性能;
- 进一步提出 p − RMSNorm p-\text{RMSNorm} p−RMSNorm,基于独立同分布的假设,使用 p % p\% p% 的参数估计整体的缩放系数,理论上进一步减少了计算量,但实验发现训练效率并未显著提升;
研究内容
LayerNorm \text{LayerNorm} LayerNorm
假设一个标准的前馈神经网络(线性层),输入向量
x
∈
R
m
x \in \mathbb{R}^{m}
x∈Rm,输出向量
y
∈
R
n
y \in \mathbb{R}^{n}
y∈Rn,神经网络的工作原理可由如下公式表示
a
i
=
∑
j
=
1
m
w
i
j
x
j
,
y
i
=
f
(
a
i
+
b
i
)
a_{i}=\sum_{j=1}^{m} w_{i j} x_{j}, y_{i}=f\left(a_{i}+b_{i}\right)
ai=j=1∑mwijxj,yi=f(ai+bi)其中
x
j
x_j
xj 是输入向量第
j
j
j 个值,
w
i
j
w_{ij}
wij (矩阵
n
×
m
n\times m
n×m)是第
i
i
i 个神经元对
x
j
x_j
xj 的权重,
f
(
⋅
)
f(\cdot)
f(⋅) 表示非线性激活(逐元相乘),
b
i
b_i
bi 代表第
i
i
i 个神经元的偏置 (bias)
神经网络模型基于独立同分布的假设,如果不显式的干预中间层的分布,神经网路模型会遭遇内部斜变量偏移的问题(每层分布的变化随着网络层的加深其效应会被放大,导致不同层的分布不仅相同,违背了基于独立同分布的假设,并且模型难以训练和收敛【一种分布其实就是一种学习到的推理特征】),如下公式描述了 LayerNorm 均值和方差的计算方式
a
‾
i
=
a
i
−
μ
σ
g
i
,
y
i
=
f
(
a
‾
i
+
b
i
)
\overline{a}_{i}=\frac{a_{i}-\mu}{\sigma} g_{i}, y_{i}=f\left(\overline{a}_{i}+b_{i}\right)
ai=σai−μgi,yi=f(ai+bi)其中,
a
ˉ
i
\bar{a}_{i}
aˉi 是向量
a
‾
∈
R
n
\overline{a} \in \mathbb{R}^{n}
a∈Rn 的第
i
i
i 个值,
g
∈
R
n
g \in \mathbb{R}^{n}
g∈Rn 表示增益参数初始设置为1,
μ
μ
μ 和
σ
2
\sigma^{2}
σ2 则是计算得到的均值和方差
RMSNorm \text{RMSNorm} RMSNorm
RMSNorm 公式如下,仅对数据分布做缩放变化:
a
‾
i
=
a
i
R
M
S
(
a
)
g
i
,
其中
R
M
S
(
a
)
=
1
n
∑
i
=
1
n
a
i
2
\overline{a}_{i}=\frac{a_{i}}{RMS(a)} g_{i}, \text{其中 } RMS(a)=\sqrt{\frac{1}{n} \sum_{i=1}^{n} a_{i}^{2}}
ai=RMS(a)aigi,其中 RMS(a)=n1i=1∑nai2
不变性分析
不变性衡量的是归一化后的模型输出是否会随着输入和权重矩阵的变化而产生显著改变,不同的归一化方法展现出不同的不变性特征,这对模型的稳健性有着重要的影响。考虑一个 RMSNorm 的一般形式:
y
=
f
(
W
x
R
M
S
(
a
)
⊙
g
+
b
)
y=f\left(\frac{W x}{RMS(a)} \odot g+b\right)
y=f(RMS(a)Wx⊙g+b)其中
⊙
\odot
⊙ 表示逐元相乘,假设权重矩阵缩放
δ
\delta
δ 倍,即
W
′
=
δ
W
W'=\delta W
W′=δW,可做如下推导:
y
′
=
f
(
W
′
x
R
M
S
(
a
′
)
⊙
g
+
b
)
=
f
(
δ
W
x
R
M
S
(
δ
a
)
⊙
g
+
b
)
=
f
(
δ
W
x
δ
R
M
S
(
a
)
⊙
g
+
b
)
=
y
y'=f\left(\frac{W' x}{RMS\left(a'\right)} \odot g+b\right)=f\left(\frac{\delta Wx}{RMS\left(\delta a\right)} \odot g+b\right)=f\left(\frac{\delta Wx}{\delta RMS\left(a\right)} \odot g+b\right)=y
y′=f(RMS(a′)W′x⊙g+b)=f(RMS(δa)δWx⊙g+b)=f(δRMS(a)δWx⊙g+b)=y
由于 R M S ( δ a ) = 1 n ∑ i = 1 n ( δ a i ) 2 = δ 2 n ∑ i = 1 n ( a i ) 2 = δ 1 n ∑ i = 1 n ( δ a i ) 2 = δ R M S ( a ) RMS(\delta a)=\sqrt{\frac{1}{n} \sum_{i=1}^{n} (\delta a_{i})^{2}}=\sqrt{\frac{\delta ^2}{n} \sum_{i=1}^{n} (a_{i})^{2}}=\delta\sqrt{\frac{1}{n} \sum_{i=1}^{n} (\delta a_{i})^{2}}=\delta RMS(a) RMS(δa)=n1∑i=1n(δai)2=nδ2∑i=1n(ai)2=δn1∑i=1n(δai)2=δRMS(a)
因此可以发现,当权重矩阵缩放时,输出具有不变性,同理将输入缩放 δ \delta δ 倍,输出依然具有不变性
因为 a = W x a=Wx a=Wx,因此无论是权重还是输入进行缩放, a a a 都会同步缩放,因此输出结果依然具有不变性
如果仅对输出中的部分或权重的部分进行缩放,输出结果则不具有不变性,因为 RMS 的线性性质被破坏
简单举例, A = [ 1 , 2 , 3 ] A=[1,2,3] A=[1,2,3], 现在对部分元素缩放得到 A ′ = [ 1 , 2 ∗ 1.5 , 3 ] A'=[1,2*1.5,3] A′=[1,2∗1.5,3], 2.16 = R M S ( A ) ≠ R M S ( A ′ ) = 2.52 2.16=RMS(A) \neq RMS(A')=2.52 2.16=RMS(A)=RMS(A′)=2.52
梯度分析
模型梯度的稳健性对于参数更新和模型收敛至关重要(论文 How Does Batch Normalization Help Optimization? 指出归一化方法之所以成功,并非源于为层输入增加了稳定性,而是因为优化态势的平滑度有所提高),这里将对 RMSNorm 进行梯度分析。给定一个损失函数
L
\mathcal{L}
L,根据链式法则,求
L
\mathcal{L}
L 对
g
g
g、
b
b
b 的导数如下:
∂
L
∂
b
=
∂
L
∂
v
⊙
∂
v
∂
b
=
∂
L
∂
v
,
∂
L
∂
g
=
∂
L
∂
v
⊙
W
x
R
M
S
(
a
)
\frac{\partial \mathcal{L}}{\partial b}=\frac{\partial \mathcal{L}}{\partial v}\odot \frac{\partial v}{\partial b}=\frac{\partial \mathcal{L}}{\partial v}, \frac{\partial \mathcal{L}}{\partial g}=\frac{\partial \mathcal{L}}{\partial v} \odot \frac{W x}{RMS(a)}
∂b∂L=∂v∂L⊙∂b∂v=∂v∂L,∂g∂L=∂v∂L⊙RMS(a)Wx其中
v
=
W
x
R
M
S
(
a
)
⊙
g
+
b
v=\frac{W x}{RMS(a)} \odot g+b
v=RMS(a)Wx⊙g+b
根据 v = W x R M S ( a ) ⊙ g + b v=\frac{W x}{RMS(a)} \odot g+b v=RMS(a)Wx⊙g+b 可得 ∂ v ∂ b = 1 \frac{\partial v}{\partial b}=1 ∂b∂v=1,所以 ∂ L ∂ b = ∂ L ∂ v \frac{\partial \mathcal{L}}{\partial b}=\frac{\partial \mathcal{L}}{\partial v} ∂b∂L=∂v∂L,此外这里的 L \mathcal{L} L 并非一个真正的损失函数,而是 RMSNorm 的输出结果,在反向传播时,对 RMSNorm 中的变量求导数
此外我们将分析 RMSNorm 对权重
W
W
W 的梯度
∂
L
∂
W
=
∑
i
=
1
n
[
x
T
⊗
(
d
i
a
g
(
g
⊙
∂
L
∂
v
)
×
R
)
]
i
,
其中
R
=
1
R
M
S
(
a
)
(
I
−
(
W
x
)
(
W
x
)
T
n
R
M
S
(
a
)
2
)
\frac{\partial \mathcal{L}}{\partial W}=\sum_{i=1}^{n}\left[x^{T} \otimes\left(diag\left(g \odot \frac{\partial \mathcal{L}}{\partial v}\right) × R\right)\right]_{i}, \text{其中 } R=\frac{1}{RMS(a)}\left(I-\frac{(W x)(W x)^{T}}{n RMS(a)^{2}}\right)
∂W∂L=i=1∑n[xT⊗(diag(g⊙∂v∂L)×R)]i,其中 R=RMS(a)1(I−nRMS(a)2(Wx)(Wx)T)
这里进行示例推导,推导过程以标量形式示例,矩阵导数类似:
- 根据链式法则 ∂ L ∂ W = ∂ L ∂ v ∂ v ∂ W , v = W x R M S ( a ) ⊙ g + b \frac{\partial \mathcal{L}}{\partial W}=\frac{\partial \mathcal{L}}{\partial v}\frac{\partial \mathcal{v}}{\partial W},\mathcal{v}=\frac{W x}{RMS(a)} \odot g+b ∂W∂L=∂v∂L∂W∂v,v=RMS(a)Wx⊙g+b
- ∂ v ∂ W = ∂ ( W x R M S ( a ) ⊙ g + b ) ∂ W = g ⊙ ∂ ( W x R M S ( a ) ) ∂ W \frac{\partial \mathcal{v}}{\partial W}=\frac{\partial (\frac{W x}{RMS(a)} \odot g+b)}{\partial W}=g\odot\frac{\partial (\frac{W x}{RMS(a)})}{\partial W} ∂W∂v=∂W∂(RMS(a)Wx⊙g+b)=g⊙∂W∂(RMS(a)Wx),这里 b , g b,g b,g 相对于 W W W 为常数
- ∂ ( W x R M S ( a ) ) ∂ W = x R M S ( a ) − W ∂ R M S ( a ) ∂ W R M S 2 ( a ) \frac{\partial (\frac{W x}{RMS(a)})}{\partial W}=x\frac{RMS(a)-W\frac{\partial RMS(a)}{\partial W}}{{RMS^2(a)}} ∂W∂(RMS(a)Wx)=xRMS2(a)RMS(a)−W∂W∂RMS(a)
- ∂ R M S ( a ) ∂ W = ∂ R M S ( a ) ∂ a ∂ a ∂ W \frac{\partial RMS(a)}{\partial W}=\frac{\partial RMS(a)}{\partial a}\frac{\partial a}{\partial W} ∂W∂RMS(a)=∂a∂RMS(a)∂W∂a,由于 a = W x a=Wx a=Wx,所以 ∂ a ∂ W = x \frac{\partial a}{\partial W}=x ∂W∂a=x,假设 s = ∑ i = 1 n a i 2 s=\sum_{i=1}^{n}a_i^2 s=∑i=1nai2,则 ∂ R M S ( a ) ∂ a = ∂ R M S ( a ) ∂ s ∂ s ∂ a \frac{\partial RMS(a)}{\partial a}=\frac{\partial RMS(a)}{\partial s}\frac{\partial s}{\partial a} ∂a∂RMS(a)=∂s∂RMS(a)∂a∂s,又 R M S ( a ) = s n RMS(a)=\sqrt {\frac{s}{n}} RMS(a)=ns,所以 ∂ R M S ( a ) ∂ s = 1 2 n 1 s = 1 2 n ⋅ R M S ( a ) \frac{\partial RMS(a)}{\partial s}=\frac{1}{2\sqrt n}\frac{1}{\sqrt s}=\frac{1}{2n\cdot RMS(a)} ∂s∂RMS(a)=2n1s1=2n⋅RMS(a)1,又 ∂ s ∂ a = 2 a \frac{\partial s}{\partial a}=2a ∂a∂s=2a,所以 ∂ R M S ( a ) ∂ W = 1 2 n ⋅ R M S ( a ) ⋅ x ⋅ 2 a = a ⋅ x n ⋅ R M S ( a ) = W x ⋅ x n ⋅ R M S ( a ) \frac{\partial RMS(a)}{\partial W}=\frac{1}{2n\cdot RMS(a)}\cdot x \cdot 2a=\frac{a \cdot x}{n\cdot RMS(a)}=\frac{Wx \cdot x}{n\cdot RMS(a)} ∂W∂RMS(a)=2n⋅RMS(a)1⋅x⋅2a=n⋅RMS(a)a⋅x=n⋅RMS(a)Wx⋅x
- 经过上述步骤,根据链式法则进行组合逐步往上带入可得 ∂ ( W x R M S ( a ) ) ∂ W = x R M S ( a ) − W ∂ R M S ( a ) ∂ W R M S 2 ( a ) = ∂ ( W x R M S ( a ) ) ∂ W = x R M S ( a ) − W W x ⋅ x n ⋅ R M S ( a ) R M S 2 ( a ) = x ⋅ 1 R M S ( a ) ( 1 − ( W x ) ( W x ) n R M S 2 ( a ) ) \frac{\partial (\frac{W x}{RMS(a)})}{\partial W}=x\frac{RMS(a)-W\frac{\partial RMS(a)}{\partial W}}{{RMS^2(a)}}=\frac{\partial (\frac{W x}{RMS(a)})}{\partial W}=x\frac{RMS(a)-W\frac{Wx \cdot x}{n\cdot RMS(a)}}{{RMS^2(a)}}=x\cdot \frac{1}{RMS(a)}(1-\frac{(Wx)(Wx)}{nRMS^2(a)}) ∂W∂(RMS(a)Wx)=xRMS2(a)RMS(a)−W∂W∂RMS(a)=∂W∂(RMS(a)Wx)=xRMS2(a)RMS(a)−Wn⋅RMS(a)Wx⋅x=x⋅RMS(a)1(1−nRMS2(a)(Wx)(Wx))
- 因此 ∂ L ∂ W = ∂ L ∂ v ⋅ g ⋅ x ⋅ 1 R M S ( a ) ( 1 − ( W x ) ( W x ) n R M S 2 ( a ) ) = x ⋅ g ⋅ ∂ L ∂ v ⋅ 1 R M S ( a ) ( 1 − ( W x ) ( W x ) n R M S 2 ( a ) ) \frac{\partial \mathcal{L}}{\partial W}=\frac{\partial \mathcal{L}}{\partial v}\cdot g \cdot x\cdot \frac{1}{RMS(a)}(1-\frac{(Wx)(Wx)}{nRMS^2(a)})=x \cdot g \cdot \frac{\partial \mathcal{L}}{\partial v} \cdot \frac{1}{RMS(a)}(1-\frac{(Wx)(Wx)}{nRMS^2(a)}) ∂W∂L=∂v∂L⋅g⋅x⋅RMS(a)1(1−nRMS2(a)(Wx)(Wx))=x⋅g⋅∂v∂L⋅RMS(a)1(1−nRMS2(a)(Wx)(Wx))
- 最后考虑到矩阵形式,整理结果 ∂ L ∂ W = ∑ i = 1 n [ x T ⊗ ( d i a g ( g ⊙ ∂ L ∂ v ) × R ) ] i \frac{\partial \mathcal{L}}{\partial W}=\sum_{i=1}^{n}\left[x^{T} \otimes\left(diag\left(g \odot \frac{\partial \mathcal{L}}{\partial v}\right) × R\right)\right]_{i} ∂W∂L=∑i=1n[xT⊗(diag(g⊙∂v∂L)×R)]i,其中 R = 1 R M S ( a ) ( I − ( W x ) ( W x ) T n R M S ( a ) 2 ) R=\frac{1}{RMS(a)}\left(I-\frac{(W x)(W x)^{T}}{n RMS(a)^{2}}\right) R=RMS(a)1(I−nRMS(a)2(Wx)(Wx)T)
将输入
x
x
x 和权重矩阵
W
W
W 进行缩放,讨论 RMSNorm 的梯度不变性,我们将
δ
x
\delta x
δx 或者
δ
W
\delta W
δW 代入
R
R
R 得到
R
′
=
1
δ
R
M
S
(
a
)
(
I
−
(
δ
W
x
)
(
δ
W
x
)
T
n
δ
2
R
M
S
(
a
)
2
)
=
1
δ
R
R'=\frac{1}{\delta RMS(a)}\left(I-\frac{(\delta W x)(\delta W x)^{T}}{n \delta^{2} RMS(a)^{2}}\right)=\frac{1}{\delta} R
R′=δRMS(a)1(I−nδ2RMS(a)2(δWx)(δWx)T)=δ1R
所以
- 将 δ x \delta x δx 代入 ∂ L ∂ W = ∑ i = 1 n [ δ x T ⊗ ( d i a g ( g ⊙ ∂ L ∂ v ) × 1 δ R ) ] i = ∂ L ∂ W \frac{\partial \mathcal{L}}{\partial W}=\sum_{i=1}^{n}\left[ \delta x^{T} \otimes\left(diag\left(g \odot \frac{\partial \mathcal{L}}{\partial v}\right) × \frac{1}{\delta}R\right)\right]_{i}=\frac{\partial \mathcal{L}}{\partial W} ∂W∂L=∑i=1n[δxT⊗(diag(g⊙∂v∂L)×δ1R)]i=∂W∂L,表明 RMSNorm 的梯度对于输入 x x x 具有缩放不变性,降低梯度对输入缩放的敏感度可确保其平滑性,并提高学习的稳定性
- 同理将 δ x \delta x δx 代入 ∂ L ∂ W \frac{\partial \mathcal{L}}{\partial W} ∂W∂L 得到 1 δ ∂ L ∂ W \frac{1}{\delta}\frac{\partial \mathcal{L}}{\partial W} δ1∂W∂L,表明 RMSNorm 的梯度对于权重矩阵 W W W 具有反比性质,这种负相关关系起到了一种隐性学习率调节器的作用,能够动态地控制梯度的范数,从而避免出现大范数的权重矩阵,并促进模型的收敛【权重过大,梯度会更小,权重过小,梯度会过大】
p RMSNorm p\text{RMSNorm} pRMSNorm
用部分数据来描述整体的分布,即将部分值用于计算 R M S ( ⋅ ) RMS(\cdot) RMS(⋅), R M S ‾ ( a ) = 1 k ∑ i = 1 k a i 2 \overline{RMS}(a)=\sqrt{\frac{1}{k} \sum_{i=1}^{k} a_{i}^{2}} RMS(a)=k1∑i=1kai2,其中 k = ⌈ n ⋅ % p ⌉ k=\lceil n \cdot \%p\rceil k=⌈n⋅%p⌉ 表示用于均方根估计的元素数量, ⌈ ⋅ ⌉ \lceil \cdot \rceil ⌈⋅⌉ 表示向上取整,不难证明 p RMSNorm p\text{RMSNorm} pRMSNorm 与 RMSNorm \text{RMSNorm} RMSNorm 具有相同的性质
作者指出,当 m m m 较小时,梯度容易出现爆炸现象,我感觉这里像是笔误,这里的 m m m 可能是输入 x x x 的维度,也可能是 p p p 写成了 m m m,此外,作者通过实验发现,当采用 p = 6.25 % p=6.25\% p=6.25% 时能够成功实现令人满意的收敛
实验结果
作者做了非常充足的实验,包含在不同的框架(PyTorch、Tensorflow、Theano )、不同的模型架构(RNN 变体、卷积和自注意力模型),不同的激活函数(sigmoid、tanh、softmax)、不同的归一化方式(无归一化、LayerNorm)等,作者实验太多了,这里就不贴图了,直接看下方结论部分
结论
- 实验证实了作者的猜想,在 LayerNorm 中,相比于 re-centering,re-scaling 起主要作用;
- RMSNorm \text{RMSNorm} RMSNorm 没有 re-centering,计算量更少,训练速度更快,并且精度不低于 LayerNorm \text{LayerNorm} LayerNorm,根据不同的网络架构、训练框架、硬件,提升性能在 7 % − 64 % 7\% -64\% 7%−64%;
- p RMSNorm p\text{RMSNorm} pRMSNorm 计算量低于 RMSNorm \text{RMSNorm} RMSNorm,在各项试验中,精度也低于 RMSNorm \text{RMSNorm} RMSNorm,但是训练速度并未显著提升,甚至可能更慢,作者认为可能是底层张量的切片算法拉低了训练效率;
- 序列任务中,不引入归一化、有概率导致梯度爆炸,训练失败,序列任务中精度 RMSNorm > LayerNorm > BatchNorm \text{RMSNorm}>\text{LayerNorm}>\text{BatchNorm} RMSNorm>LayerNorm>BatchNorm;
- 图像任务精度 BatchNorm > RMSNorm > LayerNorm \text{BatchNorm}>\text{RMSNorm}>\text{LayerNorm} BatchNorm>RMSNorm>LayerNorm;
- RMSNorm 对输入具有缩放不变性,降低梯度对输入缩放的敏感度可确保其平滑性,并提高学习的稳定性;
- RMSNorm 的梯度对于权重矩阵具有反比性质,权重过大,梯度更小,权重过小,梯度更大,能够动态地控制梯度的范数,从而避免出现大范数的权重矩阵;