【NeurIPS 2021】Luna: Linear Unified Nested Attention 线性统一嵌套注意力

#NeurIPS 2021# #Transformer#

今天分享的是NeurIPS 2021的一篇论文《Luna: Linear Unified Nested Attention》。该文章针对Transformer时间和空间复杂度高的问题,提出了一个线性统一嵌套注意力机制(Luna),实现了与各种强大的基准模型相当,甚至更好的性能。

原文标题:Luna: Linear Unified Nested Attention
作者信息:Xuezhe Ma, Xiang Kong, Sinong Wang, Chunting Zhou, Jonathan May, Hao Ma, Luke Zettlemoyer
发表会议:35th Conference on Neural Information Processing Systems (NeurIPS 2021), Sydney, Australia.
论文链接:https://arxiv.org/abs/2106.01540

摘要

该文章针对Transformer时间和空间复杂度高的问题,提出了一个线性统一嵌套注意力机制(Luna)。引入了一个额外的固定长度的序列作为输入,并产生一个额外的相应输出,使用两个嵌套注意力函数来近似Transformer中的常规softmax注意力,从而实现线性的时间和空间复杂度。与各种强大的基准模型相比,Luna实现了具有竞争力,甚至更好的性能。

1.介绍

Transformer自2017年提出,在机器翻译、语言理解、图像识别、生物信息学等广泛的语言和视觉任务上表现出了良好的效果。然而,Transformer的时间和空间复杂度都是输入句子长度的平方,这种平方的时间、空间复杂度使得Transformer难以建模很长的序列。

因此,许多人都在研究如何提高Transformer模型的时间和内存效率,一些研究者已经针对Transformer模型做了改进,能够针对长序列降低时间和空间复杂度,但对于中等长度的序列,它们的效率提升不高,并且准确度落后于Transformer。

基于此,作者提出了一个线性统一嵌套注意力机制(Luna)。

2.模型框架

在这里插入图片描述
上图中左边为一个 Transformer 编码器层的架构,右边为一个 Luna 编码器层的架构。

2.1 Pack and Unpack Attention

Luna架构的关键思想是:将传统 Transformer里的注意力解耦成两个嵌套的注意力操作,这两个操作都具有线性效率。为达到目的,引入了一个额外的输入,一个具有固定长度的序列。

(1) Pack Attention

将这个额外的输入作为查询序列,Luna使用它的第一个注意力,名为pack attention。该注意力的作用是,将context sequence(上下文序列)打包成一个固定长度的序列。

序列 P ∈ R l × d P\in\mathbb{R}^{l×d} PRl×d表示具有固定长度的额外输入序列。pack attention首先使用 P P P,将 C C C(上下文序列)打包为 Y p Y_p Yp
Y P = A t t n ( P , C ) ( 1 ) Y_P=Attn(P,C) (1) YP=Attn(P,C)1
其中, C ∈ R m × d C\in\mathbb{R}^{m×d} CRm×d Y p ∈ R l × d Y_p\in\mathbb{R}^{l×d} YpRl×d。因为 P P P的长度是一个常数 l l l,所以pack attention的复杂度是 O ( l m ) O(lm) O(lm),它相对于m是线性的。

(2) Unpack Attention

为了将序列解包,回到原始的查询序列 X X X的长度,Luna使用了第二个注意力,名为unpack attention。

Y X = A t t n ( X , Y p ) ( 2 ) Y_X=Attn(X,Y_p) (2) YX=Attn(X,Yp)2

其中 X X X是原始查询序列,与pack attention类似, unpack attention的复杂度是 O ( l n ) O(ln) O(ln),它相对于n是线性的。

(3) 额外输入序列

下一个问题是额外输入序列 P P P如何产生。一个简单的方法是将 P P P制定为每个Luna层的可学习参数。但直接采用该方法的缺点是, P P P不会捕获任何上下文信息。所以作者将 Y P Y_P YP制定为每个Luna层的附加输出。

Y X , Y P = L u n a A t t n ( X , P , C ) ( 3 ) Y_X,Y_P=LunaAttn(X,P,C) (3) YX,YP=LunaAttn(X,P,C)3

其中, 分别采用公式(1)和公式(2)计算 Y P Y_P YP Y X Y_X YX

而通过叠加多层Luna attention,捕获了来自 C C C序列的上下文信息的来自上一层的输出 Y P Y_P YP,会被用作下一层的输入 P P P。对于Luna的第一层,作者用可学习的位置嵌入生成 P P P

2.2 Luna Layers

类似Transformer层的定义,将LunaAttn和FFN,LayerNorm结合起来,可以得到Luna层的定义:

Y X , Y P = L u n a A t t n ( X , P , C ) X A , P A = L a y e r N o r m ( Y X + X ) , L a y e r N o r m ( Y P + P ) ( 4 ) X ′ , P ′ = L a y e r N o r m ( F F N ( X A ) + X A ) , P A \begin{gather*} Y_X,Y_P=LunaAttn(X,P,C)\\ X_A,P_A=LayerNorm(Y_X+X),LayerNorm(Y_P+P) (4)\\ X^{'},P^{'}=LayerNorm(FFN(X_A)+X_A),P_A \end{gather*} YX,YP=LunaAttn(X,P,C)XA,PA=LayerNorm(YX+X),LayerNorm(YP+P)4X,P=LayerNorm(FFN(XA)+XA),PA

则, X ′ X^{'} X P ′ P^{'} P即为Luna层的输出。

2.3 Luna Causal Attention

注意力机制在计算时应该只利用当前token以及当前token以前的token的信息,而不应该利用当前token之后的信息。

由于Luna Attention使用Pack Attention将输入序列压缩成了另一个长度更短的序列,因此不能像标准的self-attention一样直接mask掉之后的token,所以作者设计了Causal Attention模块。

作者首先假设 P P P不包含 X X X的信息,并且定义了causal函数:

f : R n × d 1 × R n × d 1 × R n × d 2 → R n × d 2 F ≜ f ( X , Y , Z ) , w h e r e   F t = 1 t X t ∑ j = 1 t Y j T Z j ( 5 ) \begin{gather*} f:\mathbb{R}^{n×{d_1}}×\mathbb{R}^{n×{d_1}}×\mathbb{R}^{n×{d_2}}→\mathbb{R}^{n×{d_2}} \\ F\triangleq f(X,Y,Z),where \ F_t= \frac{1}{t}X_t\sum_{j=1}^{t}Y_j^TZ_j (5) \end{gather*} f:Rn×d1×Rn×d1×Rn×d2Rn×d2Ff(X,Y,Z),where Ft=t1Xtj=1tYjTZj5

F t F_t Ft表示矩阵 F F F的第 t t t行。从 F F F的的定义可以看到,矩阵 F F F的第 t t t行表示了输入 X X X Y Y Y Z Z Z的第 t t t行以及第 t t t行以前的token的信息。

有了上面的定义,可以通过如下的步骤实现Causal Attention:

首先计算pack attention: A p a c k = ω ( P X T / d ) A_{pack}=\omega(PX^T/\sqrt{d}) Apack=ω(PXT/d ),此处未对 ω \omega ω使用softmax函数,因为softmax中的归一化项会将X的未来信息泄漏到历史中。受Linear Transformer启发,作者在此处将激活函数定义为 ω ( . ) = e l u ( . ) + 1 \omega(.)=elu(.)+1 ω(.)=elu(.)+1

接着使用causal函数计算unpack attention: A u n p a c k = ω ( f ( X , X , A p a c k T ) ) A_{unpack}=\omega(f(X,X,A_{pack}^T)) Aunpack=ω(f(X,X,ApackT))

最后的输出 Y Y Y可以表示为: Y = f ( A u n p a c k , A p a c k T , X ) Y=f(A_{unpack},A_{pack}^T,X) Y=f(Aunpack,ApackT,X)

3.实验

① 长上下文序列建模
在这里插入图片描述
表1 列出了 LRA 基准上各种模型的结果,Luna 在所有任务上都取得了比较良好的结果,并且平均准确度显著优于其他基线方法。

在这里插入图片描述
表2 表示针对不同输入长度的字节级文本分类任务,不同模型的训练速度和内存消耗峰值。

② 机器翻译
在这里插入图片描述
表4 显示了 Luna 在 WMT’14 EN→DE 测试集上的BLEU分数。

③ 用于大规模预训练的掩码语言建模
在这里插入图片描述
通过表6可以看出,在较小的数据集(16GB)上,Luna模型与其他预训练语言模型相比,具有相似或稍好的下游结果,在更大的数据集(160GB)上,Luna的性能比采用普通Transformer架构的RoBERTa稍差。

4.总结

① 提出了Luna模型:一个简单、高效、有效的线性注意力机制,替代常规的Softmax注意力。

② 通过引入具有固定长度的额外输入,Luna能够捕获足够的上下文信息,同时线性地执行注意操作。

③ 在三个序列建模任务——长上下文序列建模、机器翻译、用于大规模预训练的掩码语言建模任务上,Luna实现了与各种强大的基准模型相当,甚至更好的性能。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值