大模型训练之加速篇 -attention优化【稀疏attention->线性化attention->分块计算->简化attention->Transformer-VQ】

0 说在开头,线性attention的必要性

Attention和FFN的复杂度:对于base版来说,当序列长度不超过1536时,Transformer的复杂度都是近乎线性的;
当序列长度超过1536时,Transformer的计算量逐渐以Attention为主,复杂度慢慢趋于二次方,直到长度超过4608,才真正以二次项为主

从已有的各个高效Attention的工作中,我们可以得出结论:这些改进工作所关心的序列长度主要都是以千为单位的,有明显计算效率提升的序列长度基本上都要好几千;当然,我们前面的讨论主要针对的还是时间复杂度,对于空间复杂度,也就是显存占用量,降低的幅度一般要比时间复杂度提升的幅度的要大,但总体而言都是长序列才有价值。

所以,如果你的序列长度还只是一两百,那么就完全不要期望Attention本身的改进了,老老实实换个小点的模型就好。你可以期望未来会有更小的模型能达到同样好的效果,但是不要期望同样大的模型通过修改Attention来提升效率,因为说白了,就算把Attention完全去掉,也提升不了多少性能。

稀疏attention:sprse, reformer, linformer
线性化attention:去softmax/performer近似QK
分块计算:MQA
内存优化 flash attention

众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是O(n2)级别的,n是序列长度,所以当n比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到O(nlogn)甚至O(n)。

1 稀疏Attention

1.1 OpenAI的Sparse Attention

通过“只保留小区域内的数值、强制让大部分注意力为零”的方式,来减少Attention的计算量。经过特殊设计之后,Attention矩阵的大部分元素都是0,因此理论上它也能节省显存占用量和计算量。后续类似工作还有《Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection》、《Longformer: The Long-Document Transformer》等。

但是很明显,这种思路有两个不足之处:
1、如何选择要保留的注意力区域,这是人工主观决定的,带有很大的不智能性;
2、它需要从编程上进行特定的设计优化,才能得到一个高效的实现,所以它不容易推广。

1.2 Reformer

Reformer也是有代表性的改进工作,它将Attention的复杂度降到了𝒪(nlogn)。某种意义上来说,Reformer也是稀疏Attention的一种,只不过它的稀疏Pattern不是事先指定的,而是通过LSH(Locality Sensitive Hashing)技术(近似地)快速地找到最大的若干个Attention值,然后只去计算那若干个值。此外,Reformer通过构造可逆形式的FFN(Feedforward Network)替换掉原来的FFN,然后重新设计反向传播过程,从而降低了显存占用量。

所以,相比前述稀疏Attention,Reformer解决了它的第一个缺点,但是依然有第二个缺点:实现起来复杂度高。要实现LSH形式的Attention比标准的Attention复杂多了,对可逆网络重写反向传播过程对普通读者来说更是遥不可及~

1.3 Linformer

它依然保留原始的Scaled-Dot Attention形式,但在进行Attention之前,用两个m×n的矩阵E,F分别对K,V进行投影,即变为
Attention(Q,K,V)=softmax(Q(EK)⊤)FV

这样一来,Q(EK)⊤就只是一个n×m的矩阵,而作者声称对于哪怕对于很大的序列长度n,m也可以保持为一个适中的常数,从而这种Attention也是线性的。跟Linformer类似的思路还出现在更早一些的CV论文《Asymmetric Non-local Neural Networks for Semantic Segmentation》中。

但是,笔者认为“对于超长序列m可以保持不变”这个结论是值得质疑的,对于长序列原论文只做了MLM任务,而很明显MLM并不那么需要长程依赖,所以这个实验没什么说服力。因此,Linformer是不是真的Linear,还有待商榷。

1.4 自回归生成

Linformer的另一个缺点是EK,FV这两个运算直接把整个序列的信息给“糅合”起来了,所以它没法简单地把将来信息给Mask掉(Causal Masking),从而无法做语言模型、Seq2Seq等自回归生成任务,这也是刚才说的原作者只做了MLM任务的原因。相比之下,本文介绍的几种Linear Attention都能做到这一点。

第一种:
在这里插入图片描述

这说明这种Attention可以作为一个RNN模型用递归的方式实现,它的空间复杂度最低,但是要串性计算,适合预测解码时使用;

第二种
直接将φ(K),V∈ℝn×d做外积,得到一个n×d×d的矩阵,然后对n那一维执行cumsum运算,这样就一次性得到S1,S2,…,Sn了,它的速度最快,但空间占用最大,适合训练时使用,不过很多时候都有d2≫n,一般情况下训练时都很难承受这个空间复杂度,因此多数还是用RNN形式。

1.5 下采样技术

从结果上来看,Linformer的EK,FV就是将序列变短(下采样)了,而将序列变短的一个最朴素的方法就是Pooling了,
所以笔者之前也尝试过把Pooling技术引入到Transformer中去。
近来也有类似的工作发出来,比如IBM的《PoWER-BERT: Accelerating BERT Inference via Progressive Word-vector Elimination》和Google的《Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing》。
除了Pooling之外,其实还有其他的下采样技术,比如可以通过stride > 1的一维卷积来实现,基于这个思路,或许我们可以把FFN里边的Position-Wise全连接换成stride > 1的一维卷积?总之这方面应该也能玩出很多花样来,不过跟Linformer一样,这样糅合之后做自回归生成就很难了

2 线性化Attention(Linear Attention)

2.1 去掉标准Attention中的Softmax

就可以使得Attention的复杂度退化为理想的𝒪(n)级别(Linear Attention)。
相比于其他类似的改进结构的工作,这种修改能在把复杂度降到𝒪(n)的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。
详见博文:标准self-attention的几个变种的理解【token对token”是必须的吗】【必须有softmax吗】

2.2 近似矩阵计算q*K

一种随机投影方案,可以将标准Attention转化为线性Attention,并保持一定的近似。
理论上来说,只要投影的维度足够大,那么可以足够近似标准Attention。
换句话说,标准Attention可以视作一个无限维的线性Attention。

标准attention:

为了保留Attention相似的分布特性,我们要求sim(qi,kj)≥0恒成立。
在这里插入图片描述

2.2.1 performer

近似目标:

在这里插入图片描述

在这里插入图片描述

看起来Performer是挺不错的,那是不是说我们就可以用它来替代Transformer了?并不是,纵然Performer有诸多可取之处,但仍然存在一些无法解决的问题。

首先,为了更好地近似标准Transformer,Performer的m必须取得比较大,至少是m>d,一般是d的几倍,这也就是说,Performer的head_size要比标准Transformer要明显大。虽然理论上来说,不管m多大,只要它固定了,那么Performer关于序列长度的复杂度是线性的,但是m变大了,在序列长度比较短时计算量是明显增加的。换句话说,短序列用Performer性能应该是下降的,根据笔者的估计,只有序列长度在5000以上,Performer才会有比较明显的优势。

其次,目前看来Performer(包括其他的线性Attention)跟相对位置编码是不兼容的,因为相对位置编码是直接加在Attention矩阵里边的,Performer连Attention矩阵都没有,自然是加不了的。此外,像UniLM这种特殊的Seq2Seq做法也做不到了,不过普通的单向语言模型还是可以做到的。总之,(n2)的Attention矩阵实际上也带来了很大的灵活性,而线性Attention放弃了这个Attention矩阵,也就放弃了这种灵活性了。

最后,也是笔者觉得最大的问题,就是Performer的思想是将标准的Attention线性化,所以为什么不干脆直接训练一个线性Attention模型,而是要向标准Attention靠齐呢?从前面的最后一张图来看,Performer并没有比Linear Transformer有大的优势(而且笔者觉得最后一张图的比较可能有问题,Performer效果可能比Linear Transformer要好,但是速度怎么可能还超过了Linear Transformer?Performer也是转化为Linear Transformer的,多了转化这一步,速度怎能更快?),因此Performer的价值就要打上个问号了,毕竟线性Attention的实现容易得多,而且是通用于长序列/短序列,Performer实现起来则麻烦得多,只适用于长序列。

转载自:
Performer:用随机投影将Attention的复杂度线性化
Transformer升级之路:3、从Performer到线性Attention

2.2.2 泰勒展开

将标准Attention转换为无限维线性Attention的思路,不同于Performer的随机投影,笔者构思的这两种方案都是确定性的,并且能比较方便地感知近似程度。

在这里插入图片描述

2.2.3 指数定义

在这里插入图片描述

2.2.4 对比总结

要说实用价值,后两个确定性的方案远不如Performer的随机投影方案,因为随机投影的输出维度可以比较灵活地控制,而两个确定性的方案输出纬度则是dn级别的,这个通常都远远大于序列长度本身了,所以用它们来做线性Attention效率基本比标准的Attention还差。

不过,从理论上来讲,后两种方案提供了更为简明便捷的思路,让我们将标准Attention跟无限维的线性Attention等价起来。这种等价性通常能帮助我们更好地理解Attention机制,其中最直接的便是关于Attention的秩的理解。

做过线性Attention研究的读者应该知道,如果用线性Attention做双向注意力任务(比如MLM),那么效果下降会很明显,这是因为线性Attention的ϕ(Q),φ(K)∈ℝn×d(d是每个head的head_size),一般有n≫d,所以ϕ(Q)φ(K)⊤得到的n×n的Attention矩阵的秩顶多为d。这就是线性Attention的低秩问题,低秩限制了线性Attention的表达能力。

相反,前述介绍的三种变换,都告诉我们标准Attention可以视为无限维的线性Attention,所以标准Attention的秩理论上就不受限于d,因此同样参数量的标准Attention表现往往表现得比线性Attention好。在《Transformer升级之路:3、从Performer到线性Attention》中我们也说过,如果要将标准Attention切换为线型Attention,那么d也要相应地进行放大,才能要效果保持一定程度上的近似。

3 分块计算

3.1MQA (multi query attention)

Fast Transformer Decoding: One Write-Head is All You Need

MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。
那到底能提升多少的速度呢,我们来看论文中给出的结果图[生成每个token消耗的时间ms]:
在这里插入图片描述

从字面上看,Multi Query Attention(MQA) 和 Multi Head Attention(MHA)只差了一个单词,
就是从「Head」变成了「Query」。
我们知道,在 transformer 中是包含若干个注意力头(head)组成的,
而每个 head 又是由: query(Q),key(K),value(V) 3 个矩阵共同实现的。

「参数共享」并不是一个很新奇的思路,在 Albert 里也有通过使用跨层共享参数(Cross-layer parameter sharing)的方式来大大减少 bert 的参数量,具体做法可以参考这里:何枝:基于BERT的几种改进模型
现在,
我们知道了 MQA 实际上是将 head 中的 key 和 value 矩阵抽出来单独存为一份共享参数,
而 query 则是依旧保留在原来的 head 中,每个 head 有一份自己独有的 query 参数。

4 简化attention

4.1 FlashAttention V1

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

在这里插入图片描述

当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。
标准Attention的中间结果S,P通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N2)。

FlashAttention对HBM访问的次数为O(N2d2M-1)
Attention对HBM访问的次数为O(Nd+ N2)
往往N远远大于d(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward+Backward的GFLOPs、HBM、Runtime对比(A100 GPU):
在这里插入图片描述

GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。
综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
对应的计算过程:
每次外循环(outer loop,j)Kj Vj载入的的大小为Bcd=768d,一共循环次Tc=2次
每次内循环(inner loop,i)载入的Qi 的大小为Brd=64d,一共循环Tr=16次(总次数还需要乘以外循环)
Sij = Qi*KjT,即为(下标表示维度):。
Pij,表示和标准attention Pij计算的有区别,因为得到的row_max(Sij)最大值可能不是S第i行的最大值。的大小和一样,都为。
Pij和Sij只是部分结果,如下图所示,外循环是横向(特征维d)移动的,内循环是纵向(序列维N)移动的。换句话说,外循环在顺序计算特征,内循环在顺序计算序列。
Oi的大小为Br*d,第二维d是满的(和最终一样),这意味着每次外循环都要重新更新当前批次中的特征,即虽然第一次外循环P00*V0和第二次外循环P01*V1都会得到O0,但是第二次的O0是基于第一次O0重新生成的。
diag(……)作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。

在这里插入图片描述

GPU 知识

在这里插入图片描述
在这里插入图片描述

从Hardware角度来看:
Streaming Processor(SP):是最基本的处理单元,从fermi架构开始被叫做CUDA core。
Streaming MultiProcessor(SM):一个SM由多个CUDA core(SP)组成,每个SM在不同GPU架构上有不同数量的CUDA core,例如Pascal架构中一个SM有128个CUDA core。
SM还包括特殊运算单元(SFU),共享内存(shared memory),寄存器文件(Register File)和调度器(Warp Scheduler)等。register和shared memory是稀缺资源,这些有限的资源就使每个SM中active warps有非常严格的限制,也就限制了并行能力。

从Software(编程)角度来看:
thread:一个CUDA并行程序由多个thread来执行

thread是最基本的执行单元(the basic unit of execution)。

warp:一个warp通常包含32个thread。每个warp中的thread可以同时执行相同的指令,从而实现SIMT(单指令多线程)并行。

warp是SM中最小的调度单位(the smallest scheduling unit on an SM),一个SM可以同时处理多个warp

thread block:一个thread block可以包含多个warp,同一个block中的thread可以同步,也可以通过shared memory进行通信。

thread block是GPU执行的最小单位(the smallest unit of execution on the GPU)。

一个warp中的threads必然在同一个block中,如果block所含thread数量不是warp大小的整数倍,那么多出的那个warp中会剩余一些inactive的thread。也就是说,即使warp的thread数量不足,硬件也会为warp凑足thread,只不过这些thread是inactive状态,但也会消耗SM资源。
grid: 在GPU编程中,grid是一个由多个thread block组成的二维或三维数组。grid的大小取决于计算任务的规模和thread block的大小,通常根据计算任务的特点和GPU性能来进行调整。

Hardware和Software的联系:
SM采用的是Single-Instruction Multiple-Thread(SIMT,单指令多线程)架构,warp是最基本的执行单元,一个warp包含32个并行thread,这些thread以不同数据资源执行相同的指令。
当一个kernel被执行时,grid中的thread block被分配到SM上,大量的thread可能被分到不同的SM上,但是一个线程块的thread只能在一个SM上调度,SM一般可以调度多个block。每个thread拥有自己的程序计数器和状态寄存器,并且可以使用不同的数据来执行指令,从而实现并行计算,这就是所谓的Single Instruction Multiple Thread。
一个CUDA core可以执行一个thread,一个SM中的CUDA core会被分成几个warp,由warp scheduler负责调度。GPU规定warp中所有thread在同一周期执行相同的指令,尽管这些thread执行同一程序地址,但可能产生不同的行为,比如分支结构。一个SM同时并发的warp是有限的,由于资源限制,SM要为每个block分配共享内存,也要为每个warp中的thread分配独立的寄存器,所以SM的配置会影响其所支持的block和warp并发数量。

GPU执行模型小结:
GPU有大量的threads用于执行操作(an operation,也称为a kernel)。这些thread组成了thread block,接着这些blocks被调度在SMs上运行。在每个thread block中,threads被组成了warps(32个threads为一组)。一个warp内的threads可以通过快速shuffle指令进行通信或者合作执行矩阵乘法。在每个thread block内部,warps可以通过读取/写入共享内存进行通信。每个kernel从HBM加载数据到寄存器和SRAM中,进行计算,最后将结果写回HBM中。

4.2 FlashAttention V2

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
FlashAttention利用GPU非匀称的存储器层次结构,实现了显著的内存节省(从平方增加转为线性增加)和计算加速(提速2-4倍),而且计算结果保持一致。但是,FlashAttention仍然不如优化的矩阵乘法(GEMM)操作快,只达到理论最大FLOPs/s的25-40%。作者观察到,这种低效是由于GPU对不同thread blocks和warps工作分配不是最优的,造成了利用率低和不必要的共享内存读写。因此,本文提出了FlashAttention-2以解决这些问题。
虽然相比标准Attention,FlashAttention快了24倍,节约了1020倍内存,但是离设备理论最大throughput和flops还差了很多。本文提出了FlashAttention-2,它具有更好的并行性和工作分区。实验结果显示,FlashAttention-2在正向传递中实现了约2倍的速度提升,达到了理论最大吞吐量的73%,在反向传递中达到了理论最大吞吐量的63%。在每个A100 GPU上的训练速度可达到225 TFLOPs/s。
本文主要贡献和创新点为:

减少了non-matmul FLOPs的数量(消除了原先频繁rescale)。虽然non-matmul FLOPs仅占总FLOPs的一小部分,但它们的执行时间较长,这是因为GPU有专用的矩阵乘法计算单元,其吞吐量高达非矩阵乘法吞吐量的16倍。因此,减少non-matmul FLOPs并尽可能多地执行matmul FLOPs非常重要。
提出了在序列长度维度上并行化。该方法在输入序列很长(此时batch size通常很小)的情况下增加了GPU利用率。即使对于单个head,也在不同的thread block之间进行并行计算。
在一个attention计算块内,将工作分配在一个thread block的不同warp上,以减少通信和共享内存读/写。

在这里插入图片描述
Causal masking是attention的一个常见操作,特别是在自回归语言建模中,需要对注意力矩阵S应用因果掩码(即任何S ,其中 > 的条目都设置为−∞)。
由于FlashAttention和FlashAttention-2已经通过块操作来实现,对于所有列索引都大于行索引的块(大约占总块数的一半),我们可以跳过该块的计算。这比没有应用因果掩码的注意力计算速度提高了1.7-1.8倍。
不需要对那些行索引严格小于列索引的块应用因果掩码。这意味着对于每一行,我们只需要对1个块应用因果掩码。

并行处理
FlashAttention在batch和heads两个维度上进行了并行化:使用一个thread block来处理一个attention head,总共需要thread block的数量等于batch size × number of heads。每个block被调到到一个SM上运行,例如A100 GPU上有108个SMs。当block数量很大时(例如≥80),这种调度方式是高效的,因为几乎可以有效利用GPU上所有计算资源。
但是在处理长序列输入时,由于内存限制,通常会减小batch size和head数量,这样并行化成都就降低了。因此,FlashAttention-2还在序列长度这一维度上进行并行化,显著提升了计算速度。此外,当batch size和head数量较小时,在序列长度上增加并行性有助于提高GPU占用率。
Forward pass. FlashAttention算法有两个循环,K,V在外循环,Q,O在内循环。FlashAttention-2将Q移到了外循环i,K,V移到了内循环j,由于改进了算法使得warps之间不再需要相互通信去处理Qi,所以外循环可以放在不同的thread block上。这个交换的优化方法是由Phil Tillet在Triton[17]提出并实现的。

5 Transformer-VQ,线性transformer

论文:Transformer-VQ: Linear-Time Transformers via Vector Quantization
代码:Github:https://github.com/transformer-vq/transformer_vq

简单来说,Transformer-VQ就是对Attention的Key向量序列进行了“聚类”,并用所属类的类别中心近似原向量,然后Attention的复杂度就变成线性了。也就是说,Transformer-VQ仅仅改变了Key的形似,其余部分(理论上)完全不变,所以这是一种对Attention改动非常小的线性化方案,也能非常清楚体现出线性化后损失的精度在哪里(即用类别中心近似原向量的差距)。

5.1 VQ 一下

VQ全称是“Vector Quantize”,可以翻译为“向量量子化”或者“向量量化”,是指将无限、连续的编码向量映射为有限、离散的整数数字的一种技术。如果我们将VQ应用在自编码器的中间层,那么可以在压缩输入大小的同时,让编码结果成为一个离散的整数序列。
这里的“VQ”就是指VQ-VAE中的VQ

原有的attention:softmax(QK)V
在这里插入图片描述
其中C∈ℝc×dk,是训练参数,也是VQ的编码表(Codebook)。经过VQ之后,最直接的表现就是K的每个向量都变成了C中与之最相近的那个,这意味着K̂ 的每个向量都是C的向量之一,用数学的语言就是说K∈ℝn×dk变成了K̂ ∈Cn。

5.2 encoder

当然,直接按照式(2)去实现Transformer-VQ的话,复杂度还是二次的,但由于K̂ 的每个向量都是C
的向量之一,所以我们可以先算exp(QC⊤),然后从中“挑出”exp(QK̂ ⊤)对应的结果,而由于C
的大小是固定的,所以关键运算QC⊤的复杂度是线性的,这就是Transformer-VQ能线性化的原理(我们不妨称为“挑出”技巧)。

在这里插入图片描述

5.3 decoder

在这里插入图片描述

5.3 局域增强

当序列长度n远大于编码表大小c时,由抽屉原理我们知道部分编码向量必然会反复出现,甚至可以合理猜测所有编码向量应该会均匀分布在整个序列中。这样一来,邻近token的Attention就会跟远处某些token的Attention一样,也就是说模型无法区分远近,这本质上就是所有Kernelized Attention都存在的低秩问题。

已有的经验告诉我们,对于语言模型来说,相对于远处的token的来说邻近的token往往更为重要,所以一个好的语言模型架构应该具有区分远近的能力。为此,Transformer-VQ选择在QK̂ 之后,加上一个Sliding Window形状的Attention Bias(记为B),来对邻近token进行加权,如下图:
在这里插入图片描述
从最后一个图可以看出,如果将Window大小直接设为block大小l,即i<j或者i−j≤l时Bi,j=0,那么在分block计算时,矩阵B顶多影响最邻近的两个block,再远的block依旧可以用“挑出”技巧来线性化。为了便于下面的推导,我们记B[i,j]=B[il:(i+1)l,jl:(j+1)l],那么
在这里插入图片描述
笔者认为,B的引入是Transformer-VQ是跟其他Kernelized Attention拉开差距的关键,为了减少参数量且支持变长生成,我们约束B的非零部分为“Toeplitz矩阵”,即Bi,j是i−j的函数,此时B就相当于加性相对位置编码。除了这种做法外,也可以考虑换为笔者之前提出的ReRoPE,它是旋转位置编码的窗口版,跟B具有同样的相对位置编码形状。

5.4 梯度回传

了解VQ-VAE的读者都知道,“K̂ 的每个向量都是C的向量之一”只是前向传播的表现,反向传播用的可是原始的K,这意味着即便不同位置的K̂ j等于同一个Ck,但它们的梯度却不相等,这叫做STE(Straight-Through Estimator)。由于STE的存在,“挑出”技巧理论上仅可用于推理阶段,训练阶段是无法线性化的。

没有其他办法了吗?确实如此,如果我们坚持要获得精确的梯度结果,那么并没有线性化效率的方案。然而,考虑到VQ的梯度本身就是近似的,所以Attention获取精确的梯度似乎也没多大必要。于是作者想了个折衷的方案:依然是按照式(10)进行递归计算,仅在前两项使用STE(Key序列可以获得梯度),而Ui−1的梯度直接停掉(stop_gradient算子)。这样我们就保持了模型的线性性,同时也已经保留了最重要的梯度(邻近的两个block),算是一个比较合理的近似方案。从这一点来看Transformer-VQ跟Transformer-XL很像,Transformer-XL在递归的同时也停掉了历史窗口的梯度,即历史窗口可以参与递归计算,不传递梯度。

解决了梯度回传问题之后,在自回归交叉熵损失的基础上,再上VQ带来的用来更新编码表的辅助loss,就得到完整的训练目标了。当然,对于编码表的更新,Transformer-VQ采用了直接滑动平均的方案,所以只补充了Key的辅助loss,这些细节读者在熟悉VQ-VAE之后,稍微看一下原论文就理解了。

5.5 实验结果

值得指出的是,作者做VQ的基础架构并不是常规的MHA(Multi-Head Attention),而是笔者一直很推崇的GAU(Gated Attention Unit)+Softmax,Transformer-VQ更准确的命名应该是“GAU-VQ”,不了解GAU的读者可以参考链接。简单来说,GAU本身比MHA有着更高的效率,配合上VQ技巧后,就更加“如虎添翼”了。

实验方面,作者做了语言模型(ENWIK8、PG-19)和图像生成(IMAGENET64),所有的实验中的编码表大小都是c=512。模型最大参数量为1.3B,虽然比不上主流的大模型参数量,但其实对于科研来说不算小了。实验结果总体来说算得上优异:
在这里插入图片描述
最后,让人惊奇的是,Transformer-VQ的作者只有一个,并且身份是“Independent Researcher”。

5.6 发散思考

首先,“只需VQ一下Key,Transformer的复杂度就会变成线性”这个发现实在太美妙了,它实现了标准Attention到线性Attention的自然过渡,并且可以通过加Attention Bias的方式让它比很多的Kernelized Attention都有效。然后,通过VQ进行“聚类”的方式,也比Linformer、Nyströmformer等更为高明,因为它防止了未来信息的泄漏,可以自然地用来做Causal的语言模型。

我们知道,VQ本质上也是将序列转为离散id的运算,这跟Tokenizer的作用是非常相似的。从这个角度来看,Transformer-VQ跟MegaByte等模型一样,都是将Tokenizer内置在模型之中,并且相比MegaByte,VQ这一操作跟我们传统意义上的Tokenizer更为相似、直观。所以,Transformer-VQ实际上非常适合用来训练直接以Bytes输入的“No Tokenizer”模型,事实上,上述ENWIK8实验就是Bytes输入,Transformer-VQ效果明显优于MegaByte。

相比近来出的RetNet,Transformer-VQ没有显式的远程衰减,所以Long Context能力有可能会更好,同时由于Key经过了VQ,都是有限集合之一,所以不会出现没有学过的Key,因此长度外推能力大概率也会更好。虽然Transformer-VQ的基础架构GAU只是Single-Head的,但它在递归过程中模型记忆状态大小是Δ⊤iVi∈ℝc×dv,在默认的设置中,这比Multi-Head的RetNet还大(RetNet的记忆状态大小是nd2k,默认设置下dv=2ndk),因此,记忆容量理论上是足够的。

《简单得令人尴尬的FSQ:“四舍五入”超越了VQ-VAE》,可能会有读者想知道可否用更简单的FSQ取代VQ?笔者认为比较难,原因其实在上一篇文章给出了:第一,c=512还属于VQ优于FSQ的编码数量范围,所以换FSQ大概率会掉效果;第二,由于每层Attention的Key都要被VQ,所以平均来说VQ的Encoder和Decoder都不强,这种情况VQ近似精度更高,FSQ更适合Decoder和Decoder都足够强的场景;第三,Transformer-VQ需要用的是Key被VQ之后的中心向量而不是id,而FSQ则直接得到id,反而不容易恢复为近似的中心向量。
除此之外,用VQ而不是FSQ,使得Transformer-VQ有希望从现有的预训练模型如LLAMA2中微调过来,而不单单是从零训练。因为VQ具有鲜明的几何意义,跟K-Means有诸多相通之处,我们可以从现有预训练模型出发,选取一些样本计算出Key,对Key进行K-Means得到中心向量作为编码表的初始化,然后在原模型基础上加上VQ进行微调。不过Transformer-VQ不大好适配RoPE,所以要如前面所说,RoPE的模型要换成ReRoPE再VQ比较好,此时就可以不用加Bias了。

转载于:https://zhuanlan.zhihu.com/p/645376942

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值