LLM上下文长度扩展方案:NTK-aware interpolation

Position Interpolation存在的问题

在之前的一篇文章中讲了位置内插方案:LLM上下文长度扩展方案:Position Interpolation。PI本质上是一种线性内插,即把每个位置均匀压缩为之前的 1 S \frac{1}{S} S1,其中 S = L ′ L S=\frac{L^{'}}{L} S=LL为扩展后长度和原始长度的比值。

在RoPE中,针对 m m m位置的向量 x m x_m xm,其变换方式为:
在这里插入图片描述
其中 θ i = b − 2 ( i − 1 ) / d ( b = 10000 ,   i = 1 , 2 , 3 , . . . , d / 2 ) \theta_i = b^{-2(i - 1) / d}(b=10000,~i=1,2,3,...,d/2) θi=b2(i1)/d(b=10000, i=1,2,3,...,d/2)。PI将上式中所有的 m m m替换成了 m S \frac{m}{S} Sm,因此造成了每两个维度形成的复数在复平面上旋转的角度变成了原来的 1 S \frac{1}{S} S1

高频信息损失

在RoPE中,周期和频率有如下关系:

当base和位置m固定时,维度 i ↓ i \downarrow i,复向量旋转的角度越大,转满一圈的耗时越短,即周期越短,频率越高
当维度 i i i和位置m固定时,base越大, θ i ↓ \theta_i \downarrow θi,转满一圈的耗时越长,即周期越长,频率越低

我们主要研究第一种情况,这种情况可以用八个字概括:低维高频、高维低频。

接着,我们可以发现PI的缩放是平等对待地对待所有维度,即高频旋转角度缩小的倍数和低频旋转角度缩小的倍数是一样的,没有考虑针对不同维度作出不同的缩放。这可能会造成以下问题:对于高频低维度,插值后变得异常拥挤

造成高频低维内插后异常拥挤的原因是我们有 θ i = b − 2 ( i − 1 ) / d ( b = 10000 ,   i = 1 , 2 , 3 , . . . , d / 2 ) \theta_i = b^{-2(i - 1) / d}(b=10000,~i=1,2,3,...,d/2) θi=b2(i1)/d(b=10000, i=1,2,3,...,d/2),旋转角度随着维度 i ↑ i \uparrow i呈现指数级别减小。由于旋转角度随着维度 i ↑ i \uparrow i呈现指数级别减小,我们可以假设从低维到高维的旋转角度依次为1000 100 10 1 0.1,将旋转角度都除以缩放因子 S = 4 S=4 S=4,那么每个维度的旋转角度将依次变为250、25、2.5 、0.25 和 0.025。相比内插之前,我们可以发现维度间变得更拥挤了(差值变小)。

原本在低维度上,旋转角度较大,意味着这些维度上的信号变化非常迅速,能够精细地区分相邻位置,虽然不同维度间的差异都减小了 4 倍,但是高频低维由于量级比较大,所以相比之前降低更多,对用低维区分不同位置间的能力影响更大,显得“更拥挤”。而高维度由于本来旋转角度就很小,这种变化对高维度区分不同位置的能力影响相对较小,不会立即表现为“更拥挤”。因此,PI会在扩展倍数特别大时显著降低位置编码区分不同位置的能力,这种现象称之为高频信息的损失。

NTK-aware Scaled RoPE:高频外推+低频内插

为了解决PI中出现的问题,NTK-aware提出的改进策略为:高频外推和低频内插。即:不是将RoPE的每个维度平均缩放一个因子 S S S,而是通过减少高频的缩放和增加低频的缩放将插值压力分散到多个维度。

在讲NTK-aware之前,为了将PI和其NTK-aware,或者更广义上的多种内插方法联系起来,我们定义了如下表达式:
f ′ ( x m , m , θ ) = f ( x m , g ( m ) , h ( θ ) ) \begin{equation} f^{'}(x_m, m, \theta) = f(x_m, g(m), h(\theta)) \end{equation} f(xm,m,θ)=f(xm,g(m),h(θ))其中第 i i i个维度有 θ i = b − 2 ( i − 1 ) / d ( b = 10000 ,   i = 1 , 2 , 3 , . . . , d / 2 ) \theta_i = b^{-2(i - 1) / d}(b=10000,~i=1,2,3,...,d/2) θi=b2(i1)/d(b=10000, i=1,2,3,...,d/2)。上式从位置和旋转角度两个方面对所有内插方案进行了总结,即所有内插方案都是建立在对二者的变换上的。此时,我们可以将PI改写为:
f P I ( x m , m , θ ) = f ( x m , g ( m ) = m S , h ( θ i ) = θ i ) \begin{equation} f_{PI}(x_m, m, \theta) = f(x_m, g(m)=\frac{m}{S}, h(\theta_i)=\theta_i) \end{equation} fPI(xm,m,θ)=f(xm,g(m)=Sm,h(θi)=θi)即PI中没有对旋转角度做任何改变,只是将位置索引除以扩展比。

基于上式,NTK-aware的做法可以被表述为:
f N T K ( x m , m , θ ) = f ( x m , g ( m ) = m , h ( θ i ) = ( b ⋅ S d / ( d − 2 ) ) − 2 ( i − 1 ) / d ) \begin{equation} f_{NTK}(x_m, m, \theta) = f(x_m, g(m)=m, h(\theta_i)=(b\cdot S^{d/(d-2)})^ {-2(i-1) / d}) \end{equation} fNTK(xm,m,θ)=f(xm,g(m)=m,h(θi)=(bSd/(d2))2(i1)/d)即NTK-aware interpolation本质上就是将原始RoPE中的 θ i = b − 2 ( i − 1 ) / d \theta_i = b^{-2(i - 1) / d} θi=b2(i1)/d改为 h ( θ i ) = ( b ⋅ S d / ( d − 2 ) ) − 2 ( i − 1 ) / d h(\theta_i)=(b\cdot S^{d/(d-2)})^ {-2(i-1) / d} h(θi)=(bSd/(d2))2(i1)/d,更本质的区别是将基数base乘以了一个和扩展比 S S S有关的常量 S d / ( d − 2 ) S^{d/(d-2)} Sd/(d2)

进制编码

为什么简单将基数base乘上 S d / ( d − 2 ) S^{d/(d-2)} Sd/(d2)就能达到高频外推+低频内插呢?这要从位置编码的本质说起:

位置 n n n的RoPE编码,本质上就是数字 n n n β \beta β进制编码。

我们知道,如果要求一个十进制数字 n n n β \beta β进制数的第(从右往左数) m m m位数字,可以用如下方式:
⌊ n β m − 1 ⌋ m o d    β \left \lfloor \frac{n}{\beta^{m-1}} \right \rfloor \mod \beta βm1nmodβ例如数字15的8进制表示是17,第1位数字7可以计算为15/1 mod 8 = 7,第2位数字为15 / 7 mod 8 = 1。然后再来回忆一下Transformer中提出的Sinusoidal位置编码:
在这里插入图片描述
其中 i = 0 , 1 , 2 , . . . , d / 2 − 1 i=0,1,2,...,d/2-1 i=0,1,2,...,d/21,即偶数维度为sin,奇数维度位cos。对于第 n n n个位置,每个维度的值依次为:
[ sin ⁡ ( n ( 1000 0 2 d ) 0 ) , cos ⁡ ( n ( 1000 0 2 d ) 0 ) , sin ⁡ ( n ( 1000 0 2 d ) 1 ) , cos ⁡ ( n ( 1000 0 2 d ) 1 ) , . . . , sin ⁡ ( n ( 1000 0 2 d ) d / 2 − 1 ) , cos ⁡ ( n ( 1000 0 2 d ) d / 2 − 1 ) ] \left [ \sin(\frac{n}{(10000^{ \frac{2}{d} })^{0} }),\cos(\frac{n}{(10000^{ \frac{2}{d} })^{0} }), \sin(\frac{n}{(10000^{ \frac{2}{d} })^{1} }),\cos(\frac{n}{(10000^{ \frac{2}{d} })^{1} }), ...,\sin(\frac{n}{(10000^{ \frac{2}{d} })^{d/2-1} }),\cos(\frac{n}{(10000^{ \frac{2}{d} })^{d/2-1} }) \right ] [sin((10000d2)0n),cos((10000d2)0n),sin((10000d2)1n),cos((10000d2)1n),...,sin((10000d2)d/21n),cos((10000d2)d/21n)]此时不妨令其中 β = 1000 0 2 / d \beta=10000^{2/d} β=100002/d,则上式可以改写为:
[ sin ⁡ ( n β 0 ) , cos ⁡ ( n β 0 ) , sin ⁡ ( n β 1 ) , cos ⁡ ( n β 1 ) , . . . , sin ⁡ ( n β d / 2 − 1 ) , cos ⁡ ( n β d / 2 − 1 ) ] \left [ \sin(\frac{n}{\beta^{0} }),\cos(\frac{n}{\beta^{0} }), \sin(\frac{n}{\beta^{1} }),\cos(\frac{n}{\beta^{1} }), ...,\sin(\frac{n}{\beta^{d/2-1} }),\cos(\frac{n}{\beta^{d/2-1} }) \right ] [sin(β0n),cos(β0n),sin(β1n),cos(β1n),...,sin(βd/21n),cos(βd/21n)]这时,我们就可以将位置编码和进制转换联系起来了:

  1. 二者都含有 n β i \frac{n}{\beta^{i}} βin
  2. 二者都对 n β i \frac{n}{\beta^{i}} βin进行了一个周期操作:求余操作含有周期性,sin和cos函数也具有周期性
  3. 去除掉取整这个无关紧要的操作后,位置编码就和进制编码操作一致

因此我们可以说,位置编码本质就是求位置 n n n β \beta β进制数。有了这个发现后,为了让我们的编码方案能够编码之前的 S S S倍长度,那么需要被编码的数就变成之前最大位置索引的 S S S倍。在维度大小 d d d不变的情况下,为了编码更大的数,我们只能将进制扩大。这很好理解,假设一开始为 β = 2 \beta=2 β=2进制,那么长度为3的向量最多编码到111也就是7,而10进制却能编码到999位置。

那么我们就假设对应的进制要从 β \beta β变成 k β k\beta kβ,其中 k > 1 k > 1 k>1。为了对低频也就是高维度进行内插,我们考虑最高维度对应的 n β d / 2 − 1 \frac{n}{\beta^{d/2-1}} βd/21n,进制扩大 k k k倍后变成 n ( k β ) d / 2 − 1 \frac{n}{(k\beta)^{d/2-1}} (kβ)d/21n。为了让这个维度对应到内插操作,我们令:
n ( k β ) d / 2 − 1 = n / S β d / 2 − 1 \frac{n}{(k\beta)^{d/2-1}} = \frac{n/S}{\beta^{d/2-1}} (kβ)d/21n=βd/21n/S上式的含义为:将进制扩大 k k k倍后使用 k β k\beta kβ进制对位置 n n n进行编码等价于使用 β \beta β进制对 n / S n/S n/S(位置内插)位置进行编码,即低频内插。上式解得 k = S 2 / ( d − 2 ) k=S^{2/(d-2)} k=S2/(d2)

接着,为了对高频也就是低维度进行直接外推,我们考虑最低维度对应的 n β \frac{n}{\beta} βn,进制扩大 k k k倍后变成 n k β \frac{n}{k\beta} kβn,由于 d d d很大,所以前边解出来的 k k k趋近于1,那么其实对于高频低维度,就和改变前一样了,也就是直接外推。

综上所述,为了实现高频外推和低频内插,我们只需要将进制从 β = 1000 0 2 / d = b 2 / d \beta=10000^{2/d}=b^{2/d} β=100002/d=b2/d变成 k β = S 2 / ( d − 2 ) ⋅ b 2 / d k\beta=S^{2/(d-2)} \cdot b^{2/d} kβ=S2/(d2)b2/d即可。上面说过,NTK-aware interpolation本质上就是将原始RoPE中的 θ i = b − 2 ( i − 1 ) / d \theta_i = b^{-2(i - 1) / d} θi=b2(i1)/d改为 h ( θ i ) = ( b ⋅ S d / ( d − 2 ) ) − 2 ( i − 1 ) / d h(\theta_i)=(b\cdot S^{d/(d-2)})^ {-2(i-1) / d} h(θi)=(bSd/(d2))2(i1)/d,更本质的区别是将基数 b b b乘以了一个和扩展比 S S S有关的常量 S d / ( d − 2 ) S^{d/(d-2)} Sd/(d2)。此时我们发现,NTK-aware interpolation其实就是将进制从之前的 β \beta β变成了 k β k\beta kβ

NTK-aware interpolation通过一个简单的基数扩大操作,便巧妙实现了高频外推+低频内插。

代码实现

NTK-aware代码链接

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):

    #The method is just these three lines
    max_position_embeddings = 16384
    a = 8 #Alpha value
    base = base * a ** (dim / (dim-2)) #Base change formula

    old_init(self, dim, max_position_embeddings, base, device)

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

可以看到核心代码为:

base = base * a ** (dim / (dim-2))

即前边提到的基数扩大操作。

  • 9
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值