Efficient Streaming Language Models with Attention Sinks

Meta、麻省理工和卡耐基梅隆大学提出了一种高效流式语言模型,可以让语言模型能够支持几乎无限的上下文窗口。

paper:https://arxiv.org/abs/2309.17453
github: https://github.com/mit-han-lab/streaming-llm

evicted:驱逐 sinks:池、汇聚

研究人员想要解决的问题是:
能否在不牺牲效率和性能的情况下让大预言模型支持无限长度的上下文?

研究人员在尝试解决无限上下文长度输入时发现,限制模型能力最主要的原因是这两个:

  1. 在解码阶段,基于Transformer的 LLM 会缓存之前所有token的键值状态(KV),如下图1(a)所示,这可能会导致内存使用过多并增加解码延迟。

  2. 现有模型的长度外推能力有限,当序列长度超出预训练期间设置的注意力窗口大小时,模型性能会严重下降。

在这里插入图片描述图1

  • a.密集注意力。具有O(T2)的时间复杂度和不断增加的缓存大小。当预测文本T远远大于预训练文本长度L(T>>L),困惑度上升性能下降。深蓝色方块是不停新增token计算注意力,重新计算softmax, 然后cache到内存中(浅蓝色方块)。直达达到L是新一个token就会开始效果变差。
    简单思想可以理解只需要计算深蓝色score所在行,cache浅蓝色后,softmax重新更新当前token向量(最后一行)
  • b.窗口注意力。只维护最新的token的KV状态的固定大小的滑动窗口。问题很明显虽然在缓存最初填满后确保了恒定的内存使用和解码速度,但一旦序列长度超过缓存大小,即使只是逐出第一个token的KV,模型也会崩溃。
  • c.重新计算滑动窗口。发放为每个生成的token重建最新token的KV状态(这样一直保持有初始token)。虽然它在长文本上表现良好,但它的O(T L2)复杂性(源于上下文重新计算中的二次注意力)使得它相当慢(流应用时不行)。
  • d.StreamingLLM。保留attention sink也就是注意力池黄色方块(several initial tokens) 与最近的token结合,用于稳定的注意力计算。实验证明是高效的并且在扩展文本上提供稳定的性能。

Perplexities困惑度 are measured using the Llama-2-13B model on the first book (65K tokens) in the PG-19 test set

在探索b中窗口注意力方法为什么行不通的过程中,研究人员发现了自回归LLM的一个有趣的特征:大量的注意力分数被分配给初始token,而不管这些token与语言建模任务的相关性如何,研究人员将这些**占用了大量注意力的初始token称为——「注意力池(attention sinks)」**如下图2.

尽管这些初始token很多时候缺乏语义含义,但它们却占用了很高的注意力分数(attention scores)。
研究人员认为主要原因是因为Softmax操作,要求所有上下文token的注意力分数总和为1。因此,即使当前任务和许多先前的token不匹配,模型仍然需要在某个地方分配这些不需要的注意力分数,使得分数总和为1。

初始token为什么会获得如此之高的注意力,原因也很简单:
初始token对几乎所有后续token都是可见的,因为自回归语言模型的性质,使初始token更容易被训练为「注意力池」。

在这里插入图片描述
图2:Llama-2-7B中256个句子的平均注意力逻辑的可视化,每个句子的长度为16。观察结果包括:(1)前两层(第0层和第1层)的注意力图呈现出“局部”模式,最近的tokens受到了更多的关注。(2) 在底部两层之外,模型在所有层和头部都非常关注初始令牌

在SoftMax函数中,是e指数函数,这意味着即使输入的初始令牌 在语义上与语言建模不相关,由于指数函数的存在,SoftMax 函数的输出中它仍然会有一个非零的值。因此,模型在进行自注意力机制时,即使当前的嵌入已经包含了足够的自包含信息用于预测,模型仍然需要从其他头和层中的其他令牌中汇聚一些信息。结果,模型倾向于将不必要的注意力值“倾泻”到特定的令牌上,这就是所谓的attention sink(注意力汇聚)。【我的理解后续的token都以初始token作为注意力参照,这种注意力汇聚使得第一个token更加关注】

概括:
1、局部模式在前两层中的呈现: 在第一层和第二层(layers 0 和 1),注意力图呈现出“局部”模式,即对最近的令牌给予更多的关注。这表明模型在初步的处理阶段更注重周围的令牌,强调了局部上下文的重要性。
2、模型跨所有层和头部(heads)都强烈关注初始令牌: 在底部两层之外,也就是深层网络中,模型在所有层和头部都倾向于强烈关注初始令牌。这与“attention sink”现象相吻合,即模型在处理过程中过度关注初始令牌,而 SoftMax 函数的特性是导致这种现象的原因之一。

StreamingLLM利用了注意力池具有高注意力值的事实,保留它们可以保持注意力得分分布接近正常。因此,StreamingLLM只需将注意力汇聚令牌的KV(只有4个初始令牌就足够了)与滑动窗口的KV保持在一起,就可以锚定注意力计算并稳定模型的性能。

并证明语言模型可以预先训练为只需要一个注意力池token进行流部署。具体来说,我们建议在所有训练样本开始时使用一个额外的可学习标记可以作为指定的注意力池。\

为什么窗口注意力不行?

在这里插入图片描述

上图显示了 20K token文本的语言建模的复杂性。很明显,当文本长度超过缓存大小时,由于排除初始token而导致困惑度度激增。这表明初始token,无论与预测token的距离如何,对于维持 LLM 的稳定性至关重要。

为什么LLM在删除初始token的KV时会崩溃?

研究人员将 Llama-2-7B 和模型的所有层和头部的注意力图可视化,如图 2 所示。研究人员发现,除了底部两层之外,模型始终关注所有层和头部的初始token。含义很明确:**删除这些初始token的 KV 将删除注意力计算中 SoftMax 函数中相当一部分分母。**这种改变导致注意力分数的分布发生显着偏离正常推理设置中的预期。

在这里插入图片描述
对于初始标记在语言建模中的重要性,有两种可能的解释:(1)它们的语义至关重要,或者(2)模型学习到对它们的绝对位置的偏见。
为了区分这些可能性,我们进行了实验(表1),其中前四个令牌被换行令牌“\n“.观察结果表明,该模型仍然显著强调这些初始换行符标记。此外,重新引入它们将语言建模的困惑恢复到与原始初始标记相当的水平。这表明,起始标记的绝对位置,而不是其语义值,具有更大的意义

在这里插入图片描述
表1、initial four tokens可以恢复困惑度

使用4个tokens才有效。我们认为,出现这种模式是因为这些模型在预训练期间没有在所有输入样本中包含一致的起始标记。尽管Llama-2确实在每个段落前面加了一个"< s >"token,它是在文本分块之前应用的,导致一个主要是随机的token占据第零个位置。

在这里插入图片描述

为什么语言模型(LLMs)在注意力机制中存在过度关注初始令牌(initial tokens)的现象?

补充:softmax本质阻止所有参与的token具有零值。即使当前嵌入有足够的自包含信息用于其预测,也需要聚合来自所有层中所有头的其他token的一些信息。

由于自回归语言模型的顺序性质,初始token对所有后续token都是可见的,而后面的token仅对有限的一组后续token可见。因此,初始token更容易被训练为注意力集中器,捕获不必要的注意力。

1、自回归性质。根据之前生成的令牌来生成下一个令牌。在这个过程中,初始令牌是最早生成的令牌,因此在生成整个序列的过程中,它对所有后续令牌都是可见的。
2、可见性的不对称性:后续令牌只能被生成的一小部分令牌所看到。这种不对称性导致了对初始令牌更强烈的关注。
3、训练中的注意力聚焦点: 由于模型在训练过程中学会了将注意力集中在初始令牌上,这些令牌更容易成为“attention sink”,即吸引不必要的注意力。这可能是因为初始令牌在训练中更频繁地与后续令牌发生交互,从而更容易捕捉到一些模型认为重要的信息

具有注意力sinks的滚动KV cache

为了在已经训练好的LLM中实现LLM流,我们提出了一种简单的方法,可以在没有任何模型微调的情况下恢复窗口注意力的困惑。
除了当前的滑动窗口令牌,我们在注意力计算中重新引入了一些起始令牌的KV。StreamingLLM中的KV缓存在概念上可以分为两部分,如图4所示:(1)注意力汇(四个初始令牌)稳定注意力计算;2) 滚动KV缓存保留了最新的令牌,这对语言建模至关重要。StreamingLLM的设计是通用的,可以无缝地结合到任何采用相对位置编码的自回归语言模型中,如RoPE(Su等人,2021)和ALiBi(Press等人,2022)。

在这里插入图片描述
图4、The KV cache of StreamingLLM。黄色为注意力sinks,灰色是驱逐的tokens, 蓝色是缓存kv,红色是最新的token。

在确定相对距离并向标记添加位置信息时,**StreamingLLM关注缓存内的位置,而不是原始文本中的位置。这种区别对于StreamingLLM的性能至关重要。**例如,如果当前高速缓存具有令牌[0,1,2,3,6,7,8]并且正在解码第9个令牌的过程中,则分配的位置是[0,1,2,3,4,5,6,7],而不是原始文本中的位置,该位置将是[0,1,2,3,6,7,8,9]。

如何卸载模型过度关注初始令牌的注意力得分?

引入一个专门的“汇聚标记(可学习的占位符令牌)”(sink token)来卸载过多的注意力得分。由于这一点,模型无意中将全局可见的令牌,主要是初始令牌,作为注意力的聚焦点。提出的潜在解决方案有两个:

1、引入全局可训练的注意力汇聚标记(Sink Token): 有意地引入一个全局可训练的注意力汇聚标记,即“Sink Token”。这个标记的作用是作为一个储存不必要注意力分数的仓库。通过这种方式,模型可以有一个指定的位置,用于处理额外的注意力,避免过度关注全局可见的令牌,特别是初始令牌。

2、替代传统的 SoftMax 函数: 用 变体函数SoftMax1 替代传统的 SoftMax 函数。变体的公式如下:
在这里插入图片描述
softmax中这些概率支持一个键 - 值查找的连续值版本,因为qk算出atten,在softmax乘以v。本质上就是k-v键值查找。

分母上加 1 将改变注意力单元,不再使用真实的权重概率向量,同时有一个新的选项来提供 all-low 权重,这意味着它可以选择不对任何事情具有高置信度。

在这里插入图片描述
图3.Vanilla用了标准的sofamax注意力,Zero Sink用了Softmax1等价实现为token有着所有0的key和value特征。还有一种在softmax1上所有训练样本中预先考虑可学习的占位符token(Sink Token)。(很奇怪还是只证明了几个初始token的sinks重要性,引入sink token对于稳定注意力机制非常有效)

Remarkably, the vanilla model requires the addition of multiple tokens as attention sinks to maintain stable streaming perplexity. In contrast, the model trained with a sink token achieves satisfactory streaming
performance using just the sink token

为了验证,我们在相同的设置下从头开始用1.6亿个参数预训练三种语言模型。第一个模型使用标准的SoftMax注意力(Vanilla),第二个模型用SoftMax1(Zero Sink)取代了常规注意力机制,并且在所有训练样本中准备了可学习的占位符令牌(Sink token)。如表3所示,虽然零下沉在一定程度上缓解了注意力下沉问题,但该模型仍然依赖于其他初始令牌作为注意力下沉。引入汇聚令牌在稳定注意力机制方面非常有效。
简单地将这个sink令牌与最近的令牌配对就足以锚定模型的性能,由此产生的评估困惑甚至略有改善。鉴于这些发现,我们建议在所有样本中使用接收令牌来训练未来的LLM,以优化流部署。

实验

我们默认使用四个初始令牌作为注意力汇。
在这里插入图片描述
图5:StreamingLLM在不同LLM家族和模型尺度上对具有400万个标记的超长文本的语言建模困惑。困惑始终保持稳定。我们使用PG19(100本书)的连接测试集来进行语言建模,其中困惑波动可归因于书籍之间的转换。
在这里插入图片描述
图6、带/和不带汇点令牌的模型的训练前损失曲线。两种模型有相似的趋同趋势。
在这里插入图片描述
表4、零样本准确率。在预训练期间包含接收器令牌不会损害模型性能。

在这里插入图片描述
图7、这些可视化结果基于 256 个句子,每个句子包含 16 个令牌。左边是使用 Sink Token 的模型,右边是没有使用 Sink Token 的模型(论文画的标注不对)。两个图表显示相同的层次和头部。

1、右边没有 Sink Token 的情况: 在没有 Sink Token 的模型中,底层显示出局部关注,而在更深层次上,模型更加关注初始令牌。即模型容易在深层次上过度集中注意力在初始令牌上。

2、左边有 Sink Token 的情况: 在有 Sink Token 的模型中,可以清晰地看到在所有层次上都有关注它的明显现象,有效地聚集冗余的注意力。这表明 Sink Token 成功地成为一个注意力的集中点。

有 Sink Token 时对其他初始令牌的关注减少: 在存在 Sink Token 的情况下,相对较少的注意力被分配给其他初始令牌,支持了将 Sink Token 指定为提高流式性能的设计优势

在这里插入图片描述
如图8所示,与LongEval在大跨度设置上的单个查询不同,我们每10行新信息查询一次模型。每个查询的答案都是20行之前的,反映了真实世界中问题通常与最近的信息有关的情况。
在这里插入图片描述

如图9所示,即使输入长度接近120K令牌,采用StreamingLLM的LLM也能保持合理的准确性。相反,密集注意力和窗口注意力分别在预训练文本长度和KV缓存大小下失败。此外,我们使用了两个上下文扩展模型,LongChat-7b-v1.5-32k和Llama-2-7b-32KInstruct,以表明StreamingLLM可以补充上下文扩展技术。
在StreamingLLM中,上下文扩展意味着扩大流式LLM的最大缓存大小,从而能够捕获更广泛的本地信息。

在这里插入图片描述
在这里插入图片描述
图10.图10:带有重新计算基线的滑动窗口方法和StreamingLLM之间的每个令牌解码延迟和内存使用情况的比较,与X轴上的缓存大小(注意力窗口大小)进行了比较。StreamingLLM为每个令牌提供了高达22.2倍的显著加速,并保留了类似于重新计算基线的内存占用。

在这里插入图片描述
我们评估了缓存大小对StreamingLLM困惑的影响。**与直觉相反,增加缓存大小并不能始终降低语言建模的困惑。**这种不一致性表明了一个潜在的局限性,即这些模型可能无法最大限度地利用它们接收到的整个上下文。

通过引入“attention sinks”与最近的令牌配对,能够高效地处理长度达 4 百万令牌的文本。还通过使用具有专门的 sink token 的预训练模型,以此提高流式应用部署的性能。

缺点

1、模型无法概括超长文本也没有增强它们的长期记忆【也就是如果之前内容不重要还行,如果非常相关就有点惨了】,能输入冗长文本,只是在于从最近的token生成流畅的文本而不需要刷新缓存。
2、StreamingLLM通过只保留最新的令牌和注意力汇,丢弃中间令牌来解决这一问题。这使模型能够在不重置缓存的情况下从最近的令牌生成连贯的文本——这是早期方法中没有的功能。但本质上还是有窗口的限制,比如最大4096

早期的方法要么在会话长度超过训练长度时需要重置缓存(丢失最近的上下文),要么根据最近的文本历史重新计算KV状态,这可能很耗时。streaming对流式对话部署很有优势。

attention_sink_window_size: 2048 - group * attention_sink_size
如果采用MGA多组查询注意力机制,那么sinks就是group*4个初始需要保存,窗口大小也要缩减。但感觉分组影响不大。

参考:
1、https://mp.weixin.qq.com/s/xO57JYWKkBZ_PQ1b9eVYMg
2、https://mp.weixin.qq.com/s/Xjvg_ifh5lPkoQ2gkhY2BQ

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值