0 TL;DR
LayerNorm的重新中心化可能不是必要的,RMSNorm移除了重新中心化,降低了计算量。实验显示,训练效果和稳定性也更好!
1 背景
LayerNorm存在什么问题? 标准化目的是使训练更快,但标准化增加了计算量,降低了标准化带来的训练速度收益。
如果LayerNorm的中心化不是必要的,移除中心化是不是就减少了计算量!
2 理论
先来回顾一下LayerNorm:
神经网络的前馈网络通过线性变化➕非线性激活对输入进行投影变换:
a
i
=
∑
j
=
1
m
w
i
,
j
x
j
,
y
i
=
f
(
a
i
+
b
i
)
a_i=\sum_{j=1}^{m}w_{i,j}x_j, \quad y_i=f(a_i+b_i)
ai=j=1∑mwi,jxj,yi=f(ai+bi)
但后续网络层的输入分布会变化,出现协变量偏移问题,降低了模型收敛速度。LayerNorm对加权输入
a
\boldsymbol a
a进行归一化,固定其均值和方差:
a
ˉ
i
=
a
i
−
μ
σ
g
i
,
y
i
=
f
(
a
ˉ
i
+
b
i
)
\bar a_i=\frac{a_i - \mu}{\sigma}g_i, \quad y_i=f(\bar a_i+b_i)
aˉi=σai−μgi,yi=f(aˉi+bi)
其中, μ \mu μ和 σ \sigma σ分别是加权输入 a \boldsymbol a a的均值和标准差估计量。
LayerNorm具有重新中心化和重新缩放不变性。重新中心化对输入和权重的偏移噪声不敏感(抗偏移),重新缩放则能在输入和权重随机缩放时保证输出不变(抗伸缩)。
RMSNorm仅关注重新缩放不变性,根据均方根(RMS)统计量对加权输入进行正则化:
a
ˉ
i
=
a
i
RMS
(
a
)
g
i
,
where RMS
(
a
)
=
1
n
∑
i
=
1
1
a
i
2
\bar a_i=\frac{a_i}{\text{RMS}(\boldsymbol a)}g_i,\quad \text{where}\ \text{RMS}(\boldsymbol a)=\sqrt{\frac{1}{n}\sum_{i=1}^1a_i^2}
aˉi=RMS(a)aigi,where RMS(a)=n1i=1∑1ai2
当加权输入
a
\boldsymbol a
a的均值为零时,LayerNorm和RMSNorm完全等价。
由于RMSNorm不需要计算均值,简化了计算量,提升了训练速度!假设中心化对训练的影响很小,这就是可行的!
伸缩不变性(re-scaling invariance)
RMS具有性质:
RMS
(
α
x
)
=
α
RMS
(
x
)
\text{RMS}(\alpha \boldsymbol x) = \alpha\text{RMS}(\boldsymbol x)
RMS(αx)=αRMS(x),因此,对于RMSNorm的通用形式:
y
=
f
(
W
x
RMS
(
a
)
⊙
g
+
b
)
\boldsymbol y=f\bigg(\frac{W\boldsymbol x}{\text{RMS}(\boldsymbol a)}\odot\boldsymbol g+\boldsymbol b\bigg)
y=f(RMS(a)Wx⊙g+b)
有:
y
′
=
f
(
α
W
x
RMS
(
α
W
x
)
⊙
g
+
b
)
=
f
(
W
x
RMS
(
W
x
)
⊙
g
+
b
)
=
y
′
\boldsymbol y' =f\bigg(\frac{\alpha W \boldsymbol x}{\text{RMS}(\alpha W \boldsymbol x)}\odot\boldsymbol g+\boldsymbol b\bigg) =f\bigg(\frac{W\boldsymbol x}{\text{RMS}(W x)}\odot\boldsymbol g+\boldsymbol b\bigg) =\boldsymbol y'
y′=f(RMS(αWx)αWx⊙g+b)=f(RMS(Wx)Wx⊙g+b)=y′
3 效果


4 LayerNorm的重新中心化到底有没有用?
如下图所示,将网络权重均值随机初始化为0.2,LayerNorm收敛非常慢,但RMSNorm仍正常工作,可见 RMSNorm对权重初始均值的变化更加鲁棒 !

5 Torch源码
Qwen的实现,与T5 LayerNorm一致。
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)