[外链图片转存中…(img-brIPkSME6%A8%A-1722930016763)
上个星期,Google出了篇论文,叫做《Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention》。
论文介绍了一种新的方法,可以将基于Transformer的大语言模型接收的上下文长度拓展到无限长。
乍听之下非常唬人,不过之前Google发布的Gemini 1.5大模型就支持超长100万token上下文长度,这篇新的Infini-attention论文一出,很多人认为Gemini背后用的就是这项技术。
Attention计算
Transformer这个模型结构,从出生这一天起,业界就开始解决它应对长上下文时计算量爆炸的问题了。
在之前的文章也简单提过这一计算量,这里再重新列一遍。
我们回到Attention的计算公式中,
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \\ Attention(Q,K,V)=softmax(dkQKT)V
其中Q、K、V都是由文本输入向量乘以对应权重矩阵产生,分别有:
$Q=W_qX \$$K=W_kX \$$V=W_vX \$
X的维度由输入长度和每个token的Embedding长度决定,即[seq_length, dim],三个权重矩阵的维度分别为[dim,dim]。
那么Q、K、V矩阵分别的维度都是[seq_length,dim]。
代入Attention计算公式的第一部分,
Q K T QK^T \\ QKT
这两个矩阵的相乘结果,会得到一个维度为[seq_length, seq_length]的矩阵。
如果上下文长度超长,即seq_length极其庞大,则这个矩阵的维度也是惊人的。
目前有不少工程化的方法来解决这一问题,其核心思想都是“分而治之”。利用softmax也能局部计算的特性,分解QK矩阵的计算。有兴趣可以看看之前写的文章。
从线性Transformer以及Transformer-XL说起
任何新技术的提出,都是有迹可循的。
要讨论Infini-Transformer,先得了解一下2019年Google提出的Transformer-XL以及2020年《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》这篇论文。
Transformers are RNNs:
这篇论文提出了一个线性Attention的假设。
我们知道,attention公式中很重要的一步是Q矩阵和K矩阵相乘后,进行softmax。QK相乘产生的矩阵大小是n * n,即空间复杂度是序列长度的平方。如果说,我们能够把softmax拿掉,
Q K ⊤ V \boldsymbol{Q}\boldsymbol{K}^{\top}\boldsymbol{V} \\ QK⊤V
就是简单的三个矩阵
Q ∈ R n × d k , K ∈ R m × d k , V ∈ R m × d v \boldsymbol{Q}\in\mathbb{R}^{n\times d_k}, \boldsymbol{K}\in\mathbb{R}^{m\times d_k}, \boldsymbol{V}\in\mathbb{R}^{m\times d_v} \\ Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv
的相乘,而矩阵相乘满足结合律,我们可以先算KV,得到一个维度为[d,d]的矩阵,然后使用Q来左乘这个矩阵,因为
d ≪ n d \ll n \\ d≪n
所以复杂度可以降到O(n),即线性复杂度。
那么我们该如何做到这一点呢?
我们先将传统的带softmax的单个token的attention等价改写为以下形式:
Attention ( Q , K , V ) i = ∑ j = 1 n e q i ⊤ k j v j ∑ j = 1 n e q i ⊤ k j \text{Attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}} Attention(Q,K,V)i=j=1∑neqi⊤kjj=1∑neqi⊤kjvj
对于序列中某一个token来说,它最终的attention值,是和序列中其它token的k值分别进行点积后并归一化后,使用softmax对所有的点积进行0到1的概率分布处理,然后再和每个token的v值相乘。最终将所有相乘的结果进行加和。
实际上softmax在这里,起到的就是一个输出q和每个k相似度的作用。
我们把原始softmax的
e q i ⊤ k j e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \\ eqi⊤kj
换成一个sim函数,即比较相似度的函数,改写为以下形式:
A t t e n t i o n ( Q , K , V ) i = ∑ j = 1 n sim ( q i , k j ) v j ∑ j = 1 n sim ( q i , k j ) Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\\ Attention(Q,K,V)i=j=1∑nsim(qi,kj)j=1∑nsim(qi,kj)vj
现在重点就是,找到一个合适的sim函数。而且需要满足softmax的性质,即
sim ( q i , k j ) ≥ 0 \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0 \\ sim(qi,kj)≥0
论文中使用的是核函数变换 ϕ ( x ) \phi(x) ϕ(x) 来模拟softmax,核函数将数据映射到一个更高维的空间(核空间)。
在这个空间中,一些原本在原始空间中线性不可分的问题可能变得线性可分。
比如上图中,红色的数和蓝色的数在一维空间中是不可分的,但是通过二次函数映射,就变得线性可分了。
这种映射本质上是一种特征扩展,它可以让线性模型在这个扩展后的特征空间中以线性方式表达原本的非线性关系。
在注意力机制的上下文中,核函数被用来变换查询和键的表示,这样就可以通过简单的点积来近似原本需要通过softmax计算的复杂非线性相似度。
于是可以把单个token的Attention A t t e n t i o n ( Q , K , V ) i = V i ′ Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i=V'_i Attention(Q,K,V)i=Vi′改写为以下形式:
V i ′ = ∑ j = 1 N ϕ ( Q i T ) ϕ ( K j ) V j ∑ j = 1 N ϕ ( Q i T ) ϕ ( K j ) V'_i = \frac{\sum_{j=1}^N \phi ({Q_i}^T) \phi ({K_j}) V_j} {\sum_{j=1}^N \phi ({Q_i}^T) \phi ({K_j})}\\ Vi′=∑j=1Nϕ(QiT)ϕ(Kj)∑j=1Nϕ(QiT)ϕ(Kj)Vj
然后因为计算的是在j维度上的和,所以可以把 Q i Q_i Qi提取出来,
V i ′ = ϕ ( Q i T ) ∑ j = 1 N ϕ ( K j ) V j T ϕ ( Q i T ) ∑ j = 1 N ϕ ( K j ) V'_i = \frac{\phi({Q_i}^T) \sum_{j=1}^N \phi({K_j}) V_j^T} {\phi({Q_i}^T) \sum_{j=1}^N \phi({K_j})} \\ Vi′=ϕ(QiT)∑j=1Nϕ(Kj)ϕ(QiT)∑j=1Nϕ(Kj)VjT
在这个形式下,每个token的q其实都是和一样的 ∑ j = 1 N ϕ ( K j ) V j T \sum_{j=1}^N \phi({K_j}) V_j^T ∑j=1Nϕ(Kj)VjT进行计算,所以这个求和只用计算一次。那么完整的 V 0 . . . V n V_0...V_n V0...Vn其实就是每一个Q和这个和相乘一次。那么复杂度就只和n相关,所以将复杂度降低到了 O ( n ) O(n) O(n)。
论文中具体的 ϕ ( x ) \phi(x) ϕ(x)选择的是 elu ( x ) + 1 \text{elu}(x) + 1 elu(x)+1。
ELU ( x ) = { x if x > 0 α ( e x − 1 ) if x ≤ 0 \text{ELU}(x) = \begin{cases} x & \text{if } x > 0 \\ \alpha (e^x - 1) & \text{if } x \leq 0 \end{cases}\\ ELU(x)={xα(ex−1)if x>0if x≤0
具体为什么选择ELU,论文里随便写了几句。这个不是重点。就不再展开了。
Causal Masking
我们知道,Transformer在执行生成任务时候,Attention的计算是使用掩码矩阵的形式。
即当前的token只和上文计算Attention,形成了一个上三角掩码矩阵。
那么,之前的公式中,加和的部分,则是由第一个位置j=1加和至当前位置i。
V
i
′
=
∑
j
=
1
i
ϕ
(
Q
i
T
)
ϕ
(
K
j
)
V
j
∑
j
=
1
i
ϕ
(
Q
i
T
)
ϕ
(
K
j
)
V'_i = \frac{\sum_{j=1}^i \phi ({Q_i}^T) \phi ({K_j}) V_j} {\sum_{j=1}^i \phi ({Q_i}^T) \phi ({K_j})}
Vi′=∑j=1iϕ(QiT)ϕ(Kj)∑j=1iϕ(QiT)ϕ(Kj)Vj
V
i
′
=
ϕ
(
Q
i
T
)
∑
j
=
1
i
ϕ
(
K
j
)
V
j
T
ϕ
(
Q
i
T
)
∑
j
=
1
i
ϕ
(
K
j
)
V'_i = \frac{\phi({Q_i}^T) \sum_{j=1}^i \phi({K_j}) V_j^T} {\phi({Q_i}^T) \sum_{j=1}^i \phi({K_j})}
Vi′=ϕ(QiT)∑j=1iϕ(Kj)ϕ(QiT)∑j=1iϕ(Kj)VjT
我们使用 S i S_i Si和 Z i Z_i Zi分别表示上式中求和的部分,则有
S i = ∑ j = 1 i ϕ ( K j ) V j T S_i = \sum_{j=1}^i \phi({K_j}) V_j^T Si=∑j=1iϕ(Kj)VjT Z i = ∑ j = 1 i ϕ ( K j ) Z_i = \sum_{j=1}^i \phi({K_j}) Zi=∑j=1iϕ(Kj)
那么,可将某个Token的Attention计算公式再简化为:
V i ′ = ϕ ( Q i T ) S i ϕ ( Q i T ) Z i V'_i = \frac{\phi({Q_i}^T) S_i} {\phi({Q_i}^T) Z_i} \\ Vi′=ϕ(QiT)Ziϕ(QiT)Si
也就是说,每推断一步, S i S_i Si和 Z i Z_i Zi都进行了一次更新,即重新从j到i进行了加和。
这个形式马上就要和循环神经网络,RNN相结合了。
我们都知道,RNN的形式是,给定模型一个输入,然后产生出一个hidden state,和一个输出y。
下一步将上一步的hidden state和当前这一步的输入再输入给模型,继续产生当前步的hidden state和输出y,循环往复。
那么Casual Masking形式的其实也是一样的步骤,S和Z分别表示为attention memory和归一化memory。
这两个值实际上和RNN中的Hidden State一样,存储了之前每一步的信息。
Transformer利用这两个值,和当前Q进行计算,产生新的输出。
s 0 = 0 , z 0 = 0 , s i = s i − 1 + ϕ ( x i W K ) ( x i W V ) T , z i = z i − 1 + ϕ ( x i W K ) , y i = f l ( ϕ ( x i W Q ) T s i ϕ ( x i W Q ) T z i + x i ) . \begin{align} s_0 &= 0, \\ z_0 &= 0, \\ s_i &= s_{i-1} + \phi(x_i W_K) \left(x_i W_V\right)^T, \\ z_i &= z_{i-1} + \phi(x_i W_K), \\ y_i &= f_l\left(\frac{\phi(x_i W_Q)^T s_i}{\phi(x_i W_Q)^T z_i} + x_i\right). \end{align} s0z0siziyi=0,=0,=si−1+ϕ(xiWK)(xiWV)T,=zi−1+ϕ(xiWK),=fl(ϕ(xiWQ)Tziϕ(xiWQ)Tsi+xi).
在 S i S_i Si和 Z i Z_i Zi加和的过程中,上一步加和的结果我们是知道的,于是只需要计算当前这一步的值,加到原值即可。计算量也大大降低了。
有了这层铺垫,我们也就不难理解,为什么Infini-Attention论文中,要把memory定义成这个形式:
A m e m = σ ( Q ) M s − 1 σ ( Q ) z s − 1 A_{mem} = \frac{\sigma({Q}) M_{s-1}} {{\sigma(Q)} z_{s-1}} \\ Amem=σ(Q)zs−1σ(Q)Ms−1
以及它更新状态的公式:
M s ← M s − 1 + σ ( K ) T V M_{s} \leftarrow M_{s-1} + \sigma(K)^T V Ms←Ms−1+σ(K)TV z s ← z s − 1 + ∑ t = 1 N σ ( K t ) z_{s} \leftarrow z_{s-1} + \sum_{t=1}^N \sigma(K_t) zs←zs−1+∑t=1Nσ(Kt)
Transformers-XL
Transformer-XL,是在模型结构上做了一些改进,以应对长上下文问题。
基于Transformer架构进行NLP建模,在处理文本时,输入长度是固定的。比如BERT的限制的输入大小是512,如果不足512,则使用padding标记填充长度至512。但是如果输入的文本长度超过512,就需要使用一些技巧来应对。
最简单的方式是将输入的文本按照512的长度进行分段,然后分段进行训练。
但是这会造成两个问题,
- 上下文碎片化,由于切分段落是根据长度切分,这个处理方式并不考虑文本中的真实语义边界。有可能将完整语义的一句话,切分至两个段落。
- 冗余推理,在推理过程中,需要按照512的窗口大小,一步步向后进行推理,这样实际上会造成一定的计算冗余,效率不高。
捕获每个segment上下文信息
针对上下文碎片化的问题,引入Segment-Level recurrence mechanism来建模更长序列,它通过融合前后两个Segment的信息来到这个目的。
简单来说就是,上一个片段的一些信息,会传递至下一个片段,这样保持了上文中有价值的信息能够传递下去,不至于让每个Transformer块获取的信息是孤立的。
具体实现方式如下: 假设序列长度为L,
当前的segment为,
s τ = [ x τ , 1 , x τ , 2 , . . . , x τ , L ] \text{s}_{\tau}=[x_{\tau,1},x_{\tau,2},...,x_{\tau,L}] \\ sτ=[xτ,1,xτ,2,...,xτ,L]
后面的segment为,
s τ + 1 = [ x τ + 1 , 1 , x τ + 1 , 2 , . . . , x τ + 1 , L ] \text{s}_{\tau+1}=[x_{\tau+1,1},x_{\tau+1,2},...,x_{\tau+1,L}] \\ sτ+1=[xτ+1,1,xτ+1,2,...,xτ+1,L]
当前segment计算得出的第 层的状态向量
h τ n ∈ R L × d h_{\tau}^n \in \mathbb{R}^{L \times d} \\ hτn∈RL×d
SG ( h τ n − 1 ) \text{SG}(h_{\tau}^{n-1}) SG(hτn−1)表示不使用梯度, [ SG ( h τ n − 1 ) ∘ h τ + 1 n − 1 ] \left[ \text{SG}(h_{\tau}^{n-1}) \; \circ \;h_{\tau+1}^{n-1} \right] [SG(hτn−1)∘hτ+1n−1] 表示将前后两个Segment的输出向量在序列维度上进行拼接。
h ~ τ + 1 n − 1 = [ SG ( h τ n − 1 ) ∘ h τ + 1 n − 1 ] \tilde{h}_{\tau+1}^{n-1} = \left[ \text{SG}(h_{\tau}^{n-1}) \; \circ \;h_{\tau+1}^{n-1} \right] \\ h~τ+1n−1=[SG(hτn−1)∘hτ+1n−1]
然后,下面的公式表示获取Self-Attention计算中相应的 , , 矩阵,其中在计算 的时候仅仅使用了当前Segment的向量,在计算 和 的时候同时使用前一个Segment和当前Segment的信息。
q τ + 1 n , k τ + 1 n , v τ + 1 n = h τ + 1 n − 1 W q T , h ~ τ + 1 n − 1 W k T , h ~ τ + 1 n − 1 W v T q_{\tau+1}^{n}, \; k_{\tau+1}^n, \; v_{\tau+1}^n = h_{\tau+1}^{n-1}W_{q}^{\mathrm{ T }}, \; \tilde{h}_{\tau+1}^{n-1}W_{k}^{\mathrm{ T }}, \; \tilde{h}_{\tau+1}^{n-1}W_{v}^{\mathrm{ T }} \\ qτ+1n,kτ+1n,vτ+1n=hτ+1n−1WqT,h~τ+1n−1WkT,h~τ+1n−1WvT
最后通过Self-Attention融合计算,得出当前Segment的输出向量序列。
h τ + 1 n = Transformer-Layer ( q τ + 1 n , k τ + 1 n , v τ + 1 n ) h_{\tau+1}^n = \text{Transformer-Layer}(q_{\tau+1}^{n}, \; k_{\tau+1}^n, \; v_{\tau+1}^n) \\ hτ+1n=Transformer-Layer(qτ+1n,kτ+1n,vτ+1n)
反正就是当前segment的每一层,都和之前segment的信息进行融合,
在推理过程中,Transfomer-XL的推理过程通过直接复用上一个片段的信息,不用进行重新计算,将推理过程由逐字推理,提升到以片段为单位进行推理,这种简化带来的速度提升是成百上千倍的。
Infini Attention
然而,Infini Attention表示,Transformer-XL这种玩法不过瘾。
从它们论文中的示意图来看,Transformer-XL能有效利用的,其实只是上一个segment的信息。再之前的segment信息,其实在逐段推理中,已经遗失的差不多了。
既然每一步推理的信息都要记录下来,我为啥不直接保存全部记忆,然后在推理过程中,看哪部分记忆对当前推理最有帮助,然后我把它捞出来就行。
从它的结构图中可以看到,Infini Attention包含两个部分,
左边的部分是历史保留的信息,即记忆。右边的部分是正常的Transformer。
在每一个segment的推理中,它使用当前的Q,去左边部分,查找哪部分记忆最有用,将其提取出来,然后拼接到正常Attention计算出的结果中,再进行接下来的步骤。
A = sigmoid ( β ) ⊙ A m e m + ( 1 − sigmoid ( β ) ) ⊙ A d o t . A = \textit{sigmoid} (\beta) \odot A_{mem} + (1 - \textit{sigmoid}(\beta)) \odot A_{dot}. \\ A=sigmoid(β)⊙Amem+(1−sigmoid(β))⊙Adot.
我们可以看到,当前的Attention和记忆Attention分别取多少,其实是由一个sigmoid激活函数决定的。这里又有点LSTM网络中门控的意思了。
记忆信息也是在推理过程中实时更新的,还记得之前我们讨论线性Attention的公式吗?
A m e m = σ ( Q ) M s − 1 σ ( Q ) z s − 1 A_{mem} = \frac{\sigma({Q}) M_{s-1}} {{\sigma(Q)} z_{s-1}} \\ Amem=σ(Q)zs−1σ(Q)Ms−1
以及它更新状态的公式:
M s ← M s − 1 + σ ( K ) T V M_{s} \leftarrow M_{s-1} + \sigma(K)^T V Ms←Ms−1+σ(K)TV z s ← z s − 1 + ∑ t = 1 N σ ( K t ) z_{s} \leftarrow z_{s-1} + \sum_{t=1}^N \sigma(K_t) zs←zs−1+∑t=1Nσ(Kt)
这个更新计算量也不大,只有增量部分是需要计算的。
总结
无限(Infini)上下文其实有些标题党博眼球的成分,
首先,这里更新记忆的方式并不是无损保留记忆,对于历史信息是有舍弃有保留的。这对于超长文本的推理,仍然可能造成一些遗忘前文关键信息的问题。
其次,结构中左边的记忆部分,所占空间也不小。有人戏称,这是在用一个大模型和一个知识库,在做实时RAG。
不过这个工作的亮点确实很多,因为这个记忆模块是可插拔模块,可以嵌入到任何大模型中。而且确实在计算量上,实现了线性拓展。
而且谷歌自己发布的实验评测,这个方式也实现了挺好的效果,在80亿(8B)参数这个量级上,进行了书籍总结的长文本实验,
其指标比其它方案优势也比较显著,
总之这个工作还是能给人耳目一新的感觉,很多技术细节也似曾相识,实属文艺复兴之作。
原文首发于