LlamaRMSNorm源码解析
1. LlamaRMSNorm 介绍
LlamaRMSNorm 是 LLaMA 模型中的一种归一化层(Normalization Layer),类似于标准的层归一化(LayerNorm),但它使用的是均方根归一化(Root Mean Square Normalization,RMSNorm)。RMSNorm 是一种更轻量级的归一化方法,能够在保持性能的前提下减少计算开销。
1.1 归一化的目的
归一化层在神经网络中起到稳定训练过程、加速收敛的作用。它通过对输入数据进行归一化,减少不同输入特征的方差,进而提高模型的训练效果。传统的层归一化通常计算输入的均值和方差,然后进行归一化,而 RMSNorm 只计算输入的均方根(RMS),从而简化了计算。
1.2 RMSNorm 的计算
在 LlamaRMSNorm 中,输入张量的归一化过程如下:
- 计算均方根:对于输入张量
x
,RMSNorm 计算每个样本的均方根值。均方根值是指输入张量各个维度上所有元素的平方和的均值的平方根,即:
R M S ( x ) = 1 n ∑ i = 1 n x i 2 RMS(x)=\sqrt{\frac{1}{n} \sum_{i=1}^{n} x_{i}^{2}} RMS(x)=n1i=1∑nxi2 - 归一化:将输入张量除以均方根值,再乘以一个可学习的缩放参数
weight
,从而得到归一化后的输出。RMSNorm 的归一化公式可以表示为:
y = x R M S ( x ) + ε ⋅ w e i g h t y=\frac{x}{RMS(x)+\varepsilon }\cdot weight y=RMS(x)+εx⋅weight
其中,weight
是一个与输入张量维度相同的可学习参数,用于对归一化后的输出进行缩放。 ε \varepsilon ε是防止除零错误的极小常数。
1.3 与 LayerNorm 的比较
- 计算简化:相比 LayerNorm 需要计算均值和方差,RMSNorm 只计算均方根值,因此计算更为简便,适合在较大的模型中使用。
- 性能表现:RMSNorm 在许多场景下能够与 LayerNorm 达到类似的性能表现,尤其是在大型语言模型中,RMSNorm 的计算效率优势更加明显。
2. LlamaRMSNorm类 源码解析
源码地址:transformers/src/transformers/models/llama/modeling_llama.py
# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
import torch
from torch import nn
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) # 可学习的缩放参数,形状与输入张量的最后一个维度相同
self.variance_epsilon = eps # 用于计算均方根时防止除零错误的极小常数,默认为 1e-6
def forward(self, hidden_states):
input_dtype = hidden_states.dtype # 保存输入张量的原始数据类型
hidden_states = hidden_states.to(torch.float32) # 将输入张量转换为 float32 类型,以提高计算精度
variance = hidden_states.pow(2).mean(-1, keepdim=True) # 计算输入张量在最后一个维度上的方差
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # 使用方差和 epsilon 进行归一化处理
return self.weight * hidden_states.to(input_dtype) # 恢复原始数据类型并乘以可学习的缩放参数
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" # 返回用于打印模型时的额外信息,包括 weight 的形状和 epsilon 的值