【HuggingFace Transformers】LlamaRMSNorm源码解析

1. LlamaRMSNorm 介绍

LlamaRMSNormLLaMA 模型中的一种归一化层(Normalization Layer),类似于标准的层归一化(LayerNorm),但它使用的是均方根归一化(Root Mean Square Normalization,RMSNorm)。RMSNorm 是一种更轻量级的归一化方法,能够在保持性能的前提下减少计算开销。

1.1 归一化的目的

归一化层在神经网络中起到稳定训练过程、加速收敛的作用。它通过对输入数据进行归一化,减少不同输入特征的方差,进而提高模型的训练效果。传统的层归一化通常计算输入的均值和方差,然后进行归一化,而 RMSNorm 只计算输入的均方根(RMS),从而简化了计算。

1.2 RMSNorm 的计算

LlamaRMSNorm 中,输入张量的归一化过程如下:

  • 计算均方根:对于输入张量 xRMSNorm 计算每个样本的均方根值。均方根值是指输入张量各个维度上所有元素的平方和的均值的平方根,即:
    R M S ( x ) = 1 n ∑ i = 1 n x i 2 RMS(x)=\sqrt{\frac{1}{n} \sum_{i=1}^{n} x_{i}^{2}} RMS(x)=n1i=1nxi2
  • 归一化:将输入张量除以均方根值,再乘以一个可学习的缩放参数 weight,从而得到归一化后的输出。RMSNorm 的归一化公式可以表示为:
    y = x R M S ( x ) + ε ⋅ w e i g h t y=\frac{x}{RMS(x)+\varepsilon }\cdot weight y=RMS(x)+εxweight
    其中,weight 是一个与输入张量维度相同的可学习参数,用于对归一化后的输出进行缩放。 ε \varepsilon ε是防止除零错误的极小常数。

1.3 与 LayerNorm 的比较

  • 计算简化:相比 LayerNorm 需要计算均值和方差,RMSNorm 只计算均方根值,因此计算更为简便,适合在较大的模型中使用。
  • 性能表现:RMSNorm 在许多场景下能够与 LayerNorm 达到类似的性能表现,尤其是在大型语言模型中,RMSNorm 的计算效率优势更加明显。

2. LlamaRMSNorm类 源码解析

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
import torch

from torch import nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # 可学习的缩放参数,形状与输入张量的最后一个维度相同
        self.variance_epsilon = eps  # 用于计算均方根时防止除零错误的极小常数,默认为 1e-6

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype  # 保存输入张量的原始数据类型
        hidden_states = hidden_states.to(torch.float32)  # 将输入张量转换为 float32 类型,以提高计算精度
        variance = hidden_states.pow(2).mean(-1, keepdim=True)  # 计算输入张量在最后一个维度上的方差
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  # 使用方差和 epsilon 进行归一化处理
        return self.weight * hidden_states.to(input_dtype)  # 恢复原始数据类型并乘以可学习的缩放参数

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"  # 返回用于打印模型时的额外信息,包括 weight 的形状和 epsilon 的值
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值