均方根层标准化(RMSNorm: Root Mean Square Layer Normalization)

0 TL;DR

LayerNorm的重新中心化可能不是必要的,RMSNorm移除了重新中心化,降低了计算量。实验显示,训练效果和稳定性也更好

1 背景

LayerNorm存在什么问题? 标准化目的是使训练更快,但标准化增加了计算量,降低了标准化带来的训练速度收益。

如果LayerNorm的中心化不是必要的,移除中心化是不是就减少了计算量!

2 理论

先来回顾一下LayerNorm:
神经网络的前馈网络通过线性变化➕非线性激活对输入进行投影变换:
a i = ∑ j = 1 m w i , j x j , y i = f ( a i + b i ) a_i=\sum_{j=1}^{m}w_{i,j}x_j, \quad y_i=f(a_i+b_i) ai=j=1mwi,jxj,yi=f(ai+bi)

但后续网络层的输入分布会变化,出现协变量偏移问题,降低了模型收敛速度。LayerNorm对加权输入 a \boldsymbol a a进行归一化,固定其均值和方差:
a ˉ i = a i − μ σ g i , y i = f ( a ˉ i + b i ) \bar a_i=\frac{a_i - \mu}{\sigma}g_i, \quad y_i=f(\bar a_i+b_i) aˉi=σaiμgi,yi=f(aˉi+bi)

其中, μ \mu μ σ \sigma σ分别是加权输入 a \boldsymbol a a的均值和标准差估计量。

LayerNorm具有重新中心化和重新缩放不变性。重新中心化对输入和权重的偏移噪声不敏感(抗偏移),重新缩放则能在输入和权重随机缩放时保证输出不变(抗伸缩)。

RMSNorm仅关注重新缩放不变性,根据均方根(RMS)统计量对加权输入进行正则化:
a ˉ i = a i RMS ( a ) g i , where RMS ( a ) = 1 n ∑ i = 1 1 a i 2 \bar a_i=\frac{a_i}{\text{RMS}(\boldsymbol a)}g_i,\quad \text{where}\ \text{RMS}(\boldsymbol a)=\sqrt{\frac{1}{n}\sum_{i=1}^1a_i^2} aˉi=RMS(a)aigi,where RMS(a)=n1i=11ai2
当加权输入 a \boldsymbol a a的均值为零时,LayerNorm和RMSNorm完全等价。

由于RMSNorm不需要计算均值,简化了计算量,提升了训练速度!假设中心化对训练的影响很小,这就是可行的!

伸缩不变性(re-scaling invariance)
RMS具有性质: RMS ( α x ) = α RMS ( x ) \text{RMS}(\alpha \boldsymbol x) = \alpha\text{RMS}(\boldsymbol x) RMS(αx)=αRMS(x),因此,对于RMSNorm的通用形式:
y = f ( W x RMS ( a ) ⊙ g + b ) \boldsymbol y=f\bigg(\frac{W\boldsymbol x}{\text{RMS}(\boldsymbol a)}\odot\boldsymbol g+\boldsymbol b\bigg) y=f(RMS(a)Wxg+b)

有:
y ′ = f ( α W x RMS ( α W x ) ⊙ g + b ) = f ( W x RMS ( W x ) ⊙ g + b ) = y ′ \boldsymbol y' =f\bigg(\frac{\alpha W \boldsymbol x}{\text{RMS}(\alpha W \boldsymbol x)}\odot\boldsymbol g+\boldsymbol b\bigg) =f\bigg(\frac{W\boldsymbol x}{\text{RMS}(W x)}\odot\boldsymbol g+\boldsymbol b\bigg) =\boldsymbol y' y=f(RMS(αWx)αWxg+b)=f(RMS(Wx)Wxg+b)=y

3 效果


4 LayerNorm的重新中心化到底有没有用?

如下图所示,将网络权重均值随机初始化为0.2,LayerNorm收敛非常慢,但RMSNorm仍正常工作,可见 RMSNorm对权重初始均值的变化更加鲁棒

5 Torch源码

Qwen的实现,与T5 LayerNorm一致。

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值