【Llama源码】归一化RMSNorm

数学公式与代码

RMSNorm是在Layer Norm之上的改进,它通过舍弃中心不变性来降低计算量。

a ‾ i = a i R M S ( a ) g i 其中, R M S ( a ) = 1 n ∑ r = 1 n a i 2 \overline a_i = \frac {a_i}{RMS(a)} g_i \\ 其中,RMS(a)=\sqrt { { \frac1n}{\sum_{r=1}^n a_i^2}} ai=RMS(a)aigi其中,RMS(a)=n1r=1nai2

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
    	# 求RMS(a)
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # 计算 ai/RMS(a) 
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
		# 计算 ai/RMS(a) * gi
        return self.weight * hidden_states 

torch.rsqrt(input, *, out=None) 函数

针对输入input的每个元素的平方根的倒数来返回一个新的Tensor。

o u t i = 1 i n p u t i out_i =\frac1{\sqrt {input_i}} outi=inputi 1

  • Example
x = torch.tensor([1,4,9,16])
torch.rsqrt(x)
  • Result
tensor([1.0000, 0.5000, 0.3333, 0.2500])

参考链接

llama 代码详读 - RMSNorm
RMSNorm的原理和代码

  • 13
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值