新一代注意力机制Lightning Attention-2

 Datawhale分享 

推荐:王茂霖,Datawhale成员

大语言模型的序列长度的限制,极大的制约了大语言模型在人工智能领域的应用,比如多轮对话、长文本理解、多模态数据的处理与生成等。造成这一限制的根本原因在于当前大语言模型均采用的Transformer架构有着相对于序列长度的二次计算复杂度。这意味着随着序列长度的增加,需要的计算资源成几何倍数提升。如何高效地处理长序列一直是大语言模型的挑战之一。

之前的方法往往集中在如何让大语言模型在推理阶段适应跟长的序列。比如采用Alibi或者类似的相对位置编码的方式来让模型自适应不同的输入序列长度,亦或采用对RoPE等类似的相对位置编码进行差值的方式,在已经完成训练的模型上再进行进一步的短暂精调来达到扩增序列长度的目的。这些方法只是让大模型具有了一定的长序列建模能力但实际训练和推理的开销并没有减少。

OpenNLPLab团队尝试一劳永逸地解决大语言模型长序列问题。他们提出并开源了Lightning Attention-2, 一种新型的线性注意力机制让长序列的训练和推理成本与1K序列长度的一致。在遇到显存瓶颈之前,无限地增大序列长度并不会对于模型训练速度产生负面影响。这让无限长度预训练成为了可能。同时,超长文本的推理成本也与1K Tokens的成本一致甚至更少,这将极大地减少当前大语言模型的推理成本。如下图所示,在400M、1B、3B的模型大小下,随着序列长度的增加,FlashAttention2加持的LLaMA的训练速度开始快速下降,然而Lightning Attention-2 加持的TansNormerLLM的速度几无变化。

7a1c47fa15716ba28417b04e280e343e.png

图1

Lightning Attention-2被知名 AI 博主 AK 转发,并入选Hugging Face 每日必读论文Daily Papers之一 :

cc0ff071bd9a67b114cd8ab2771f1f3c.png

开源地址:https://github.com/OpenNLPLab/lightning-attention

Lightning Attention-2简介

让大模型的预训练速度在不同序列长度下保持一致,这听起来是一个不可能任务。事实上,如果一个注意力机制的计算复杂度相对于序列长度保持线性关系的话,就可以实现这一点。自2020年 线性注意力 横空出世以来,研究人员一直在为了线性注意力的实际效率符合它的理论线性计算复杂度而努力。在2023年之前,大多数的关于线性注意力的工作均集中在对齐它们与Transformer的精度上。终于在2023年中期,改进的线性注意力机制在精度上可以与最先进的Transformer架构对齐。然而,线性注意力中将计算复杂度变成线性的最关键的“左乘变右乘”的计算Trick (如下图所示),在实际实现中远慢于直接左乘的算法。其原因在于右乘的实现需要用到包含大量循环操作的累积求和(cumsum),大量的IO操作使得右乘的效率远低于左乘。

b6a0b22bb269bab3bb8fbef7d5a143c4.png图 2

为了更好地理解Lightning Attention-2 的思路,让我们先回顾下传统softmax attention 的计算公式:O=softmax((QK^T)⊙M)V,其中Q, K, V, M, O 分别为query, key, value, mask和输出矩阵,这里的M在单向任务(如GPT)中是一个下三角的全1矩阵,在双向任务(如Bert)中则可以忽略,即双向任务没有mask矩阵。我们将Lightning Attention-2 的整体思路总结为以下三点进行解释:

  1. Linear Attention 的核心思想之一就是去除了计算成本高昂的 softmax 算子,使 Attention 的计算公式可以写为O=((QK^T)⊙M)V。但由于单向任务中mask 矩阵 M的存在,使得该形式依然只能进行左乘计算,从而不能获得O(N)的复杂度。但对于双向任务,由于没有mask矩阵,Linear Attention 的计算公式可以进一步简化为O=(QK^T)V。Linear Attention 的精妙之处在于,仅仅利用简单的矩阵乘法结合律,其计算公式就可以进一步转化为:O=Q(K^T V),我们将这种计算形式称之为右乘,相对应的前者为左乘。通过图 2 可以直观地理解到Linear Attention 在双向任务中可以达到诱人的O(N)复杂度!

  2. 但是随着decoder-only 的 GPT 形式的模型逐渐成为 LLM 的事实标准,如何利用Linear Attention 的右乘特性加速单向任务成为了亟待解决的难题。为了解决这个问题,我们提出了利用“分而治之”的思想,将注意力矩阵的计算分为对角阵和非对角阵两种形式,并采用不同的方式对他们进行计算。如图 3 所示,Linear Attention-2 利用计算机领域常用的Tiling思想,将Q, K, V 矩阵分别切分为了相同数量的块(blocks)。其中 block 自身(intra-block)的计算由于 mask 矩阵的存在,依然保留左乘计算的方式,具有O(N^2)的复杂度;而block 之间(inter-block)的计算由于没有 mask 矩阵的存在,可以采用右乘计算方式,从而享受到O(N)的复杂度。两者分别计算完成后,可以直接相加得到对应第i块的Linear Attention 输出Oi。同时,对 KV 的状态进行累加以在下一个block 的计算中使用。这样我们就得到了整个Lightning Attention-2 的算法复杂度为intra-block 的O(N^2)和inter-block 的O(N)的Trade-off。怎么取得更好的Trade-off 则是由Tiling 的block size 决定的。

  3. 细心的读者会发现,以上的过程只是Lightning Attention-2 的算法部分,之所以取名Lightning 是因为我们充分考虑了该算法过程在GPU硬件执行过程中的效率问题。受到 FlashAttention 系列工作的启发,实际在 GPU 上进行计算的时候,我们将切分后的Qi, Ki, Vi 张量从 GPU 内部速度更慢容量更大的 HBM 搬运到速度更快容量更小的 SRAM 上进行计算,从而减少大量的memory IO 开销。当该block完成Linear Attention 的计算之后,其输出结果 Oi 又会被搬回至 HBM。重复这个过程直到所有block 被处理完毕即可。

想要了解更多细节的读者可以仔细阅读本文中的Algorithm 1 和 Algorithm 2,以及论文中的详细推导过程。Algorithm 以及推导过程都对Lightning Attention-2 的前向和反向过程进行了区分,可以帮助读者有更深入的理解。

7b15ed7cc0a3127d9a390f338a708216.png图 3

887635383bec82361c0cdce204156c65.png 10ba9ef4695bb40a14473a50f7ffc03a.png

Lightning Attention-2精度对比

研究人员首先在小规模(400M)参数模型上对比了Lightning Attention-2与Lightning Attention-1的精度区别,如下图所示,二者几无差别。

101438f3c1d4a17a949c62c1d40e9918.png

随后研究人员在1B、3B上将Lightning Attention-2加持的TransNormerLLM(TNL-LA2)与其它先进的非Transformer架构的网络以及FlashAttention2加持的LLaMA在相同的语料下做了对比。如下图所示,TNL-LA2与LLaMA保持了相似的趋势,并且loss的表现更优。这个实验表明,Lightning Attention-2在语言建模方面有着不逊于最先进的Transformer架构的精度表现。

8480caa25f431582d4f3ed28651ccab0.png

在大语言模型任务中,研究人员对比了TNL-LA2 15B与Pythia在类似大小下的大模型常见Benchmark的结果。如下表所示,在吃掉了相同tokens的条件下,TNL-LA2在常识推理和多项选择综合能力上均略高于基于Softmax 的注意力的Pythia模型。

e26eaf65c41d1bcde0b492308a652963.png

Lightning Attention-2速度对比

研究人员对Lightning Attention-2与FlashAttention2进行了单模块速度与显存占用对比。如下图所示,相比于Lightning Attention-1和FlashAttention2,在速度上,Lightning Attention-2表现出了相比于序列长度的严格线性增长。在显存占用上,三者均显示出了类似的趋势,但Lightning Attention-2的显存占用更小。这个的原因是FlashAttention2和Lightning Attention-1的显存占用也是近似线性的。

0ff241b2a8e105d574206e56ab808fb2.png

笔者注意到,这篇文章主要关注点在解决线性注意力网络的训练速度上,并实现了任意长度的长序列与1K序列相似的训练速度。在推理速度上,并没有过多的介绍。这是因为线性注意力在推理的时候可以无损的转化为RNN模式,从而达到类似的效果,即推理单token的速度恒定。对于Transformer来说,当前token的推理速度与它之前的token数量相关。

笔者测试了Lightning Attention-1加持的TransNormerLLM-7B与常见的7B模型在推理速度上的对比。如下图所示,在近似参数大小下,Lightning Attention-1的吞吐速度是百川的4倍,ChatGLM的3.5倍以上,显示出了优异的推理速度优势。

16f0c76a57b5d350279d2518c0e9806b.png

TransNormerLLM 15B(集成Lightning Attention-2)的最新Benchmark结果

613b3676fd5ddc376edd18ab2f68db43.png

小结

Lightning Attention-2代表了线性注意力机制的重大进步,使其无论在精度还是速度上均可以完美地替换传统的Softmax 注意力,为今后越来越大的模型提供了可持续扩展的能力,并提供了一条以更高效率处理无限长序列的途径。OpenNLPLab团队在未来将研究基于线性注意力机制的序列并行算法,以解决当前遇到的显存屏障问题。

d9b8b866d6f96f4853b7197248556538.png

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值