一、LLaMA的核心改进全景
Meta开源的LLaMA模型凭借其卓越的性能表现成为大模型发展的重要里程碑。相较于标准Transformer架构,LLaMA主要在以下几个方面进行了关键改进:
- 位置编码升级:采用旋转位置编码(Rotary Position Embedding, RoPE)
- 归一化革新:对每个Transformer子层的输入进行归一化(Pre-normalization)而非传统Transformer结构中对输出进行归一化(Post - normalization),并使用RMS-Norm替代传统LayerNorm。
- 激活函数优化:引入SwiGLU激活函数取代ReLU非线性函数,以提高性能。
- 注意力优化(LLaMA 2):引入分组查询注意力(Grouped Query Attention)
这些改进显著提升了模型的计算效率和长文本处理能力,今天我们来学习一下均方根误差标准化(RMS-Norm)。
其余部件的学习链接持续更新中,欢迎关注:
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之旋转编码RoPE(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之均方根误差标准化RMSNorm(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之SwiGLU激活函数(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之分组查询注意力(Grouped Query Attention)(含代码实现)
二、均方根误差标准化(RMS-Norm)
2.1 改进动机
在深度学习中,归一化(Normalization)技术对于加速模型训练和提升泛化能力至关重要。传统的 Transformer 架构通常采用 LayerNorm 进行归一化操作。在原始 Transformer 中,使用的是 PostNorm,即对每个子层的输出进行归一化。但这种方式在训练深层模型时,梯度容易出现不稳定的情况,导致训练效率低下和收敛困难等问题。而 LLaMA 改用 PreNorm,也就是对每个 Transformer 子层的输入进行归一化,这样可以有效缓解梯度不稳定的问题,使得模型在训练过程中梯度能够更加平稳地传播,进而加快训练速度,提高训练的稳定性。
然而,LayerNorm在计算过程中需要计算输入特征的均值,这在某些情况下可能会引入不必要的计算开销,并且均值的计算可能会对输入特征的信息造成一定的损失。为了解决这些问题,RMSNorm作为一种改进的归一化方法应运而生。RMSNorm仅计算输入特征的均方根(Root Mean Square),而不考虑均值。这种方法不仅减少了计算开销,还能更好地保留输入特征的信息,从而在某些任务上取得更好的性能。
从缩放(scaling)和平移(shifting)的角度来看,LayerNorm的主要收益来自于缩放操作,而非平移操作。具体而言,LayerNorm通过计算输入特征的均值和方差,对输入进行标准化,然后应用可学习的缩放因子(γ)和偏移量(β)来调整输出。然而,研究表明,LayerNorm的性能提升主要源自于缩放操作,而平移操作对性能的影响相对较小。这意味着,去除平移操作可能不会显著影响模型性能。
基于这一发现,RMSNorm提出了仅进行缩放操作的归一化方法。RMSNorm通过计算输入特征的均方根,对输入进行标准化,然后应用可学习的缩放因子(γ)来调整输出。这种方法简化了计算过程,减少了计算开销,同时保留了LayerNorm的主要优势。RMSNorm在多个任务上表现出色,尤其是在处理长序列数据时,能够显著提高训练效率和模型性能。例如,在自然语言处理任务中,使用RMSNorm的模型在训练速度和泛化能力上均优于使用LayerNorm的模型。
2.2 数学原理
RMS-Norm的计算过程相对简单。给定一个输入向量 x = { x 1 , x 2 , … , x d } x = \{x_1, x_2, \ldots, x_d\} x={x1,x2,…,xd},其中 d d d 是向量的维度。RMS-Norm的计算公式如下:
首先,计算输入向量
x
x
x 的均方根:
r
m
s
(
x
)
=
1
d
∑
i
=
1
d
x
i
2
rms(x) = \sqrt{\frac{1}{d}\sum_{i = 1}^{d}x_{i}^{2}}
rms(x)=d1i=1∑dxi2
然后,对输入向量
x
x
x 进行归一化操作:
x
^
=
x
r
m
s
(
x
)
\hat{x} = \frac{x}{rms(x)}
x^=rms(x)x
最后,通过一个可学习的缩放参数
γ
\gamma
γ 对归一化后的向量进行缩放:
y
=
γ
⋅
x
^
y = \gamma \cdot \hat{x}
y=γ⋅x^
2.3 源码实现
下面是RMS-Norm的Python代码实现,使用PyTorch框架:
import torch
import torch.nn as nn
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
初始化RMSNorm归一化层。
参数:
dim (int): 输入张量的维度。通常对应输入特征的数量,例如在形状为 (batch_size, sequence_length, feature_dim) 的输入张量中,dim 应为 feature_dim。
eps (float, 可选): 一个小的数值,添加到分母中以保证数值稳定性,防止除零错误。默认值为 1e-6。
属性:
eps (float): 用于数值稳定性的小数值,添加到分母中。
weight (nn.Parameter): 可学习的缩放参数,形状为 (dim,),初始值为全 1。在训练过程中会自动更新,以调整归一化后输出的尺度。
"""
super().__init__()
# 保存用于数值稳定的小常量
self.eps = eps
# 定义可学习的缩放参数,初始化为全 1
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
对输入张量应用RMSNorm归一化操作。
该方法的具体步骤为:首先计算输入张量在最后一个维度上的平方的均值,然后加上 eps 以保证数值稳定性,接着求其平方根的倒数,最后将输入张量与该倒数相乘完成归一化。
参数:
x (torch.Tensor): 待归一化的输入张量,其形状可以是任意的,但最后一个维度的大小应与初始化时的 dim 参数相等。
返回:
torch.Tensor: 归一化后的张量,形状与输入张量相同。
"""
# 计算输入张量每个元素的平方
squared_x = x.pow(2)
# 计算平方后张量在最后一个维度上的均值,keepdim=True 保持维度不变,方便后续广播操作
mean_squared = squared_x.mean(-1, keepdim=True)
# 为均值加上 eps 以防止除零错误
rms_denominator = mean_squared + self.eps
# 计算 rms_denominator 的平方根的倒数
rms_reciprocal = torch.rsqrt(rms_denominator)
# 将输入张量与 rms 倒数相乘,完成归一化
return x * rms_reciprocal
def forward(self, x):
"""
RMSNorm层的前向传播过程。
该方法首先调用 _norm 方法对输入张量进行归一化,然后将归一化后的张量转换回与输入张量相同的数据类型,最后与可学习的缩放参数 weight 相乘得到最终输出。
参数:
x (torch.Tensor): 输入张量。
返回:
torch.Tensor: 经过RMSNorm处理后的输出张量。
"""
# 将输入张量转换为浮点数类型进行归一化操作
normalized_output = self._norm(x.float())
# 将归一化后的输出张量转换回与输入张量相同的数据类型
normalized_output = normalized_output.type_as(x)
# 将归一化后的输出与可学习的缩放参数相乘
return normalized_output * self.weight
if __name__ == "__main__":
# 定义输入张量的维度
dim = 10
# 创建RMSNorm实例
rms_norm = RMSNorm(dim)
# 生成一个随机输入张量,形状为 (2, 5, dim)
input_tensor = torch.randn(2, 5, dim)
# 打印输入张量的形状
print(f"输入张量的形状: {input_tensor.shape}")
# 将输入张量传入RMSNorm进行归一化处理
output_tensor = rms_norm(input_tensor)
# 打印输出张量的形状
print(f"输出张量的形状: {output_tensor.shape}")
# 打印输出张量
print("输出张量:")
print(output_tensor)