0x0. 前言
上篇文章 flash-linear-attention中的Chunkwise并行算法的理解 根据GLA Transformer Paper(https://arxiv.org/pdf/2312.06635 作者是这位大佬 @sonta)通过对Linear Attention的完全并行和RNN以及Chunkwise形式的介绍理解了Linear Attention的Chunkwise并行算法的原理。但是paper还没有读完,后续在paper里面提出了Gated Linear Attention Transformer,它正是基于Chunkwise Linear Attention的思想来做的,不过仍有很多的工程细节需要明了。这篇文章就来继续阅读一下paper剩下的部分,把握下GLA的计算流程以及PyTorch实现。下面对Paper的第三节和第四节进行理解,由于个人感觉Paper公式有点多,所以并没有对paper进行大量直接翻译,更多的是读了一些部分之后直接大白话一点写一下我对各个部分的理解和总结。这样可能会忽略一些细节,建议读者结合原Paper阅读。
这里需要说明的是,在上篇文章里面介绍到的Chunk并行算法实际上不是GLA这篇Paper首次提出的idea。GLA这篇paper是在工程上极大改进了Chunk并行算法,使得它的效率更高。改进的细节正是paper的第三节和第四节介绍的核心内容。不过我在 https://github.com/sustcsonglin/flash-linear-attention 官方仓库以及Paper给出的GLA算法伪代码中都看到只有一次分块,不太清楚原因。此外,Paper的实验中也没有把GLA Transformer Scale Up到更大的规模,这个可能是受限于算力之类的原因,不过最近看到 https://arxiv.org/abs/2405.18428 和 https://arxiv.org/abs/2405.18425 2篇比较新的Paper都是用GLA重做了一些经典的大模型架构,所以相信它是未来可期的。
0x1. Hardware-Efficient Linear Attention
paper描述了一种名为FLASHLINEARATTENTION的算法,这是一种面向输入/输出且硬件高效的线性注意力算法,它和与FLASHATTENTION相似。这一节讨论在实际高效的实现中需要考虑的硬件方面的问题。
0x1.1 硬件优化的准则
一个高效的算法应考虑现代硬件上的计算模型、内存层次结构和专用计算单元。
- Occupancy GPU有许多并行执行的线程;这些线程被分组为线程块,并在流式多处理器(SM)上执行。为了保持高GPU占用率(即GPU资源的使用比例),需要使用足够数量的SM。在大规模训练和长序列建模场景中,批处理大小往往较小,通过序列维度并行化可以实现高GPU占用率。
- 专用计算单元 用于神经网络训练的现代硬件通常具有专用计算单元(例如NVIDIA GPU上的Tensor Core,TPU上的矩阵乘法单元),这些单元可以显著加速矩阵乘法。例如,在A100 GPU上,半精度矩阵乘法在Tensor Core上的速度大约是CUDA Core的16倍。利用这些专用单元对于训练大规模神经网络尤为重要。
- 内存层次结构 GPU具有内存层次结构,包括较大但速度较慢的全局GPU内存(高带宽内存,HBM)和较小但速度较快的共享内存(SRAM)。因此,优化SRAM的利用以减少HBM的I/O成本可以显著提高速度。
0x1.2 针对Linear Attention的硬件考虑
我自己总结下,这一节主要是对递归形式,并行形式,以及Chunkwise并行形式进行了再次说明,paper中提到对于递归形式来说虽然flops较低但是由于要在时间步上频繁访问HBM并且无法使用Tensor Core导致实际效率很低。而对于并行形式来说,它的效率可以做到和FLASHATTENTION一致,但是当序列长度很长时,训练成本会快速增加。最后,对于Chunk形式的并行,它可以利用上Tensor Core,但是之前提出的一些实现效率较低,比如在2k-4k序列长度下是比FLASHATTENTION更慢的。
0x1.3 FLASHLINEARATTENTION: 具有块状形式的硬件高效线性注意力
这一节直接读paper还不是很好懂,其实讲的就是说FLASHLINEARATTENTION算法有一个materialize参数来控制是否要重计算S,然后在计算过程中无论是否要重计算S都会遵循分块加载Q,K,V到共享内存中,然后我们就可以重用共享内存上的块状Tensor来避免多次加载HBM I/O。例如,对于Algorithm1中的materialize为True的情况,当 Q [ n ] Q[n] Q[n] 被加载到SRAM时, Q [ n ] S Q[n]S Q[n]S 和 ( Q [ n ] K T [ n ] ⊗ M V ) [ n ] (Q[n]K^T[n] \otimes MV)[n] (Q[n]KT[n]⊗MV)[n] 可以在芯片上计算,这样可以避免再次加载 Q [ n ] Q[n] Q[n](从而节省HBM I/O)。
对于materialize为False的情况(非重计算版本),算法首先在HBM中把块间递归的结果存下来(对应Paper里的方程2),然后将所有 S [ n ] S[n] S[n](对所有 n ∈ [ N ] n \in [N] n∈[N])都并行计算在HBM中。该方法有更好的并行性,但略微增加了内存占用。非重计算版本顺序计算 S [ n ] S[n] S[n](对所有 n ∈ [ N ] n \in [N] n∈[N]),并使用SRAM暂时存储 S [ n ] S[n] S[n]。这种策略在内存上更高效,但缺乏序列级别的并行性。然而,在后向Pass过程中重计算隐藏状态 S [ n ] S[n] S[n] 会引入大约30%的多余FLOPs。因此,非重计算版本通常比重计算版本速度更慢,但节省了更多GPU内存。
图1展示了这两种方法。
这张图画得挺好的,我们可以清楚的看到对于materialize为False的情况下,Q,K,V都是从HBM中加载到SRAM,每次都会计算出一个新的隐藏状态S出来,注意这个S无需保存所以它一直存在于SRAM上,整体的计算过程是一个Sequential的。而对于materialize为True的情况,首先通过KV计算出S并将S保存到HBM中,这部分也是Sequence的。计算完S之后就可以Chunk并行的计算出 O i O_{i} Oi。这里的箭头表示每个操作需要的操作数,和上文的公式是完全对得上的。
图2展示了FLASHLINEARATTENTION实现的速度和内存占用情况。两种版本的FLASHLINEARATTENTION都比FlashAttention-2(Dao, 2023)和纯PyTorch(即不I/O感知)实现的chunkwise线性注意力快得多,展示了I/O感知的好处。所有方法都具有线性空间复杂度。非重计算版本具有最小的内存占用,而重计算版本的内存占用略高于FlashAttention-2。
0x2. Gated Linear Attention
方程1中的线性递归没有衰减项或遗忘门,而这在RNN中已被证明是至关重要的。缺少衰减项使得模型难以“忘记”信息,这被假设为部分导致线性注意力在长上下文任务中不稳定的原因。最近的研究通过在线性注意力中加入一个全局的、与数据无关的衰减因子 γ ∈ ( 0 , 1 ) \gamma \in (0, 1) γ∈(0,1) 获得了更好的性能: S t = γ S t − 1 + k t T v t S_t = \gamma S_{t-1} + k^T_t v_t St=γSt−1+ktTvt。使用单一的 γ \gamma γ 旨在保持注意力样式的并行形式,以实现高效训练。在paper中,作者考虑了一种与数据相关的门控机制用于线性注意力。我们展示了尽管有一个更具表达力的门控因子,所得到的门控线性注意力(GLA)层仍然可以采用硬件高效的chunkwise方式进行高效训练。
0x2.1 GLA的递归和并行形式
递归形式。GLA 有一个二维遗忘门 G t ∈ ( 0 , 1 ) d × d G_t \in (0,1)^{d \times d} Gt∈(0,1)d×d:
S t = G t ⊙ S t − 1 + k t T v t , S_t = G_t \odot S_{t-1} + k_t^T v_t, St=Gt⊙St−1+ktT