「ArXiv2020」【Efficient Transformers: A Survey】论文笔记
Paper:https://arxiv.org/abs/2009.06732
Abstract
这篇工作主要总结了对于改进Transformer模型计算和存储有效性(computational and memory efficiency)的各类"X-former"工作,例如:Reformer、Linformer、Performer、Longformer等。提供了对各领域现有高效Transformer模型的一个有组织的广泛的总结和概述.
1. Introduction
self-attention机制是Transformer模型的关键定义特征。这个机制可以看做一个类似图的感应偏置(graph-like inductive bias),将一个序列当中所有tokens和一个基于相关性的pooling运算相连接。对于self-attention的一个疑虑就是其二次时间和存储复杂度,其阻碍了模型在许多设置下的可扩展性(scalability)。为了解决这一问题,许多模型的变体近期被提出,本文称这些工作为"efficient Transformers".
在不同语境下对于efficiency的不同. 可以是内存的占用或计算的消耗(例如预测和训练需要的FLOPs). 本文对于efficiency的定义为对于large input下的计算和内存复杂度.
对于长序列进行建模时,高效的self-attention模型是十分关键的。例如对于图像、文件、视频等,其大多由相对较大数量的像素或tokens组成。因此高效处理长序列对于Transformer模型的广泛应用是十分关键的.
这篇survey主要关注为了解决self-attention机制问题所提出的对于Transformer模型建模的进展以及模型结构的革新(modeling advances and architecture innovation),同时也简短的讨论了一些一般性的改进以及其他高效性的改进。
这篇工作对高效Transformer模型进行了分类,根据技术的革新以及主要用例对不同工作进行了characterizing。特别的,文章回顾了Transformer在语言以及视觉领域的应用,建立了不同模型的关系。
2. Background on Transformers
这一节主要回顾了Transformer模型的原文,下面也稍微复习一下。
Transformer模型是一个多层的模型结构,由许多的Transformer blocks堆叠而成。Transformer blocks的特点就是multi-head self-attention机制+position-wise 前馈网络+layer normalization模块+残差连接。Transformer模型的输入一般是
R
B
×
R
N
\mathbb{R}^B\times \mathbb{R}^N
RB×RN的tensor,其中
B
B
B是batch size,
N
N
N是序列长度。
输入首先经过embedding层,将每一个one-hot的token表示转换为
d
d
d维的embedding,例如
R
B
×
R
N
×
R
D
\mathbb{R}^B\times \mathbb{R}^N\times \mathbb{R}^D
RB×RN×RD。而后新的tensor和位置编码相加组合,然后经过multi-head self-attention模块。其中,位置编码结构可以是原文的正弦输入,也可以是可训练的embeddings。
multi-head self-attention模块的输入和输出由残差模块和layer normalization层相连接。multi-head self-attention模块的输出而后经过两层前馈网络(也是残差连接+layer norm)。则sub-layer的残差连接+layer norm可以表示为:
X
=
LayerNorm
(
F
S
(
X
)
)
+
X
X = \textrm{LayerNorm}(F_S(X))+X
X=LayerNorm(FS(X))+X
其中,
F
S
(
X
)
F_S(X)
FS(X)是sub-layer模块,其可以是multi-head self-attention或者position-wise前馈网络。
2.1 Multi-Head Self-Attention
multi-head self-attention机制的核心是学一个alignment,其中序列中的每一个元素从其他token中收集。(The key idea behind the mechanism is to learn an alignment in which each element in the sequence learns to gather from other tokens in the sequence.【个人理解就是学一个互相之间的相关性】)。对于一个head的运算可以定义为:
A
h
=
Softmax
(
α
Q
h
K
h
T
)
V
h
A_h=\textrm{Softmax}(\alpha Q_hK_h^T)V_h
Ah=Softmax(αQhKhT)Vh其中,
Q
h
=
W
q
X
Q_h=\textbf{W}_qX
Qh=WqX,
K
h
=
W
k
X
K_h=\textbf{W}_kX
Kh=WkX,
V
h
=
W
v
X
V_h=\textbf{W}_vX
Vh=WvX是对于输入序列时间维度的线性变换。
W
q
,
W
,
W
v
∈
R
d
×
d
N
h
\textbf{W}_q,\textbf{W},\textbf{W}_v\in \mathbb{R}^{d\times \frac{d}{N_h}}
Wq,W,Wv∈Rd×Nhd是query、key和value映射的权值矩阵,将输入
X
X
X映射为
d
d
d维的输出tensor。
N
h
N_h
Nh是head数目。
X
X
X是一个
R
N
×
R
d
\mathbb{R}^N\times \mathbb{R}^d
RN×Rd维的矩阵,
α
\alpha
α是一个尺度因子,一般设为
1
d
\frac{1}{\sqrt{d}}
d1。设head的数量为
N
H
N_H
NH,那么heads
A
1
⋯
A
N
H
A_1\cdots A_{N_H}
A1⋯ANH的输出被concatenated到一起,并经过一个稠密layer。输出
Y
Y
Y则可以被表示为
Y
=
W
O
[
A
1
⋯
A
N
H
]
Y=\textbf{W}_O[A_1\cdots A_{N_H}]
Y=WO[A1⋯ANH],其中
W
O
\textbf{W}_O
WO是输出的线性映射。注意
A
A
A的计算一般考虑由tensor
R
B
×
R
N
×
R
N
h
×
R
d
×
d
N
h
\mathbb{R}^{B} \times \mathbb{R}^{N}\times \mathbb{R}^{N_h}\times \mathbb{R}^{d\times \frac{d}{N_h}}
RB×RN×RNh×Rd×Nhd并行计算,且所有head的线性变换并线计算。
attention矩阵
A
=
Q
K
T
A=QK^T
A=QKT主要负责学习序列中token之间的alignment scores。在这个公式中,在Q和K中的每一个元素/token之间的点积被取出,从而推动了self-attention中的self-alignment过程,从而使得token学会gather from each other。(In this formulation, the dot product between each element/token in the query (Q) and key (K) is taken. This drives the self-alignment process in self-attention whereby tokens learn to gather from each other.)
On the scalability of Self-Attention
从上文的公式中可以看出,计算attention矩阵的计算和内存复杂度是输入序列长度的平方( N 2 N^2 N2)。特别是 Q K T QK^T QKT矩阵相乘本身消耗了 N 2 N^2 N2的时间和内存。这限制了self-attention模型在需要处理长序列时的应用。后文将介绍解决这一问题的方法。
2.2 Position-wise Feed-forward Layers
self-attention模块的输出而后输入两层的前馈网络,激活为ReLU。每一个position的前馈层运算是独立的,因此称为position-wise。公式为:
F
2
(
R
e
L
U
(
F
1
(
X
A
)
)
)
F_2(ReLU(F_1(X_A)))
F2(ReLU(F1(XA)))其中
F
1
,
F
2
F_1,F_2
F1,F2为前馈函数,形式为
W
x
+
b
Wx+b
Wx+b。
2.3 Putting it all together
每一个Transformer block公式如下:
X
A
=
LayerNorm
(
MultiheadSelfAttention
(
X
)
)
+
X
X
B
=
LayerNorm
(
PositionFFN
(
X
A
)
)
+
X
A
X_A=\textrm{LayerNorm}(\textrm{MultiheadSelfAttention}(X))+X\\ X_B=\textrm{LayerNorm}(\textrm{PositionFFN}(X_A))+X_A
XA=LayerNorm(MultiheadSelfAttention(X))+XXB=LayerNorm(PositionFFN(XA))+XA其中
X
X
X是Transformer block的input,
X
B
X_B
XB是Transformer block的output。
2.4 Transformer Mode
Transformer模型根据使用的方式不同可以主要分为三个mode:
(1)encoder-only(例如:分类任务)
(2)decoder-only(例如:语言模型)
(3)encoder-decoder(例如:机器翻译)
对于encoder-decoder mode,一般包含多个multi-headed self-attention模块,其中包括encoder和decoder中的标准self-attention模块以及一个encoder-decoder的cross-attention让decoder来利用encoder的信息。这影响了self-attention机制的设计。
对于encoder mode,对self-attention机制是否需要是causal的是没有限制的,即完全依赖于过去和现在的token。
对于encoder-decoder的设置,encoder和encoder-decoder的cross attention可以是non-causal的,但是decoder的self-attention必须是causal的。
支持causal auto-regressive decoding的能力是设计高效self-attention机制是需要的,因为这是在许多应用中的一个限制因素。
3. A Survey of Efficient Transformer Models
3.1 A Taxonomy of Efficient Transformers
本章根据模型的核心方法和主要用途将高效Transformer进行分类。除了基于segment-based recurrence模型外大部分模型的主要目标是近似二次cost的attention矩阵(approximate the quadratic-cost attention matrix)。每一个方法将一些稀疏的方法用于原本稠密的attention机制。
-
Fixed Patterns(FP) 最早对于self-attention的改动就是简单的通过限制field of view为固定值,预定义patterns(例如local windows以及固定步长的block patterns)对attention矩阵进行稀疏化,。
- Blockwise Patterns 这一方法在实际中最简单的例子就是blockwise(或chunking)范例,其将输入序列分为固定的block来考虑局部感受野(local receptive fields)。这样的工作包括Blockwise(Qiu et al., 2019)以及Local Attention(Parmar et al., 2018)。将输入序列分为blocks可以将复杂度从 N 2 N^2 N2减少为 B 2 B^2 B2且 B < < < N B<<<N B<<<N,从而很大程度的减少cost。这样的方法是许多复杂方法的基础。
- Strided Patterns 另一种方法是去考虑strided attention patterns,即按固定间隔attending。这类工作包括Sparse Transformer(Child et al., 2019)以及Longformer(Beltagy et al., 2020)使用strided或dilated(膨胀)式的视窗。
- Compressed Patterns 另一种方法是使用一些pooling运算去下采样输入序列长度为固定的pattern。例如,Compressed Attention(Liu et al., 2018)使用strided卷积来有效的减少序列长度。
-
Combination of Patterns(CP) 此类方法的核心是通过组合两个或更多的不同访问模式(distinct access patterns)去提高覆盖面(coverage)。例如,Sparse Transformer(Child et al., 2019)通过一半的heads分配给pattern将strided和local attention组合。相似的,Axial Transformer(Ho et al., 2019)在给定高维input张量的前提下,沿着input张量的每一个single axis应用一系列的self-attention计算。本质上讲,基于组合patterns的方法和fixed patterns的方法减少内存复杂度的方式相同。区别是多patterns的组合提高了self-attention机制的总体覆盖面。(In essence, the combination of patterns reduces memory complexity in the same way that fixed patterns does. The difference, however, is that the aggregation and combinaton of multiple patterns improves the overall coverage of the self-attention mechanism.)
-
Learnable Patterns(LP) 可学习的patterns是对fixed(pre-determined)patterns的一种延伸。使用可学习的patterns建模是为了以数据驱动的方式学习access pattern。这一类patterns的关键是去确定token的相关性,将token分配给buckets或者clusters。Reformer(Kitaev et al., 2020)提出了一种基于哈希相似度的度量来高效的将token聚类为chunks。类似的,Routing Transformer(Roy et al 2020)将线上的 k k k-means聚类用于tokens。Sinkhorn Sorting Network(Tay et al., 2020b)通过学习输入序列的block排序来暴露(expose)attention权重的sparsity。LP模式的关键仍是利用fixed pattern。然而,这一类学习去对input token的排序/聚类可以得到一个序列的更优全局view同时保留fixed patterns的优势。
-
Memory 另一种方式是利用一个side memory模块来同时访问多个token。一种通用的形式是global memory来访问整个序列。global token作为一种从输入序列token中学习gather的memory形式。这种方法第一次提出于Set Transformer(Lee et al., 2019)称为 inducing points 方法。这些参数通常理解为"memory",且用于处理临时上下文信息。这可以被看做一种参数attention(Sukhbaatar et al., 2019)。Global memory被用于ETC(Ainslie et al., 2020)以及Longformer(Beltagy et al., 2020)。使用有限的memory(inducing points),我们可以对输入序列进行类似pooling的运算进行压缩,这是在设计高效self-attention模块时可以用的一个trick。
-
Low-Rank Methods 另一种新兴的技术是通过低秩近似self-attention矩阵来提高效率。核心的想法是假设 N × N N\times N N×N矩阵的低秩结构。Linformer(Wang et al., 2020b)是这一类技术的经典例子。其将keys换个values的length维度映射为低维表示( N → k N\to k N→k)。因此self-attention矩阵从 N × N N\times N N×N降为 N × k N\times k N×k,从而改善了内存复杂度的问题。
-
Kernels 另一种近期比较火的方法是去通过kernelization来view attention机制。 kernel方法(Katharopoulos et al., 2020; Choromanski et al., 2020)可以更巧妙的对self-attention机制进行数学重写,来避免对 N × N N\times N N×N矩阵的显式计算。因为核方法是对attention矩阵的近似,因此其也可以看做low-rank方法的一种(Choromanski et al., 2020)。
-
Recurrence 对blockwise方法的一种直接拓展是通过recurrence来连接这些blocks。Transformer-XL(Dai et al., 2019)提出了segment-level recurrence机制来连接多个segments和blocks。这些模型在某些方面可以看做fixed patterns。然而,由于其不同于其他block/local方法,我们将其单独归类。
以上不同类别并没有明确的边界,不同工作直接也有类别的交叉。
3.2 Detailed Walk-through of Efficient Transformer Models
这一节对efficient Transformer的几篇代表工作做了详细的讨论。
结构如下:
- 首先是基于local and fixed pattern的一些工作,包括:Memory Compressed Transformer(Liu et al., 2018)以及Image Transformer(Parmar et al., 2018);
- 然后是一篇早期的利用global memory的工作Set Transformer(Lee et al., 2019);
- 然后是使用combinations of patterns的工作,包括:Sparse Transformer(Child et al., 2019)以及Axial Transformer(Ho et al., 2019);
- 然后是Sparse Transformer方向内的memory-based的工作,包括:Longformer(Beltagy et al., 2019)以及ETC(Ainslie et al., 2020);
- 然后是基于incorporate(合并) learnable patterns(LP),包括:Routing Transformer(Roy et al., 2020),Reformer(Kitaev et al., 2020)以及Sinkhorn Transformer(Tay et al., 2020b);
- 然后是基于low-rank分解的工作,包括:Linformer(Wang et al., 2020b)以及Synthesizers(Tay et al., 2020a);
- 然后是基于kernel的方法,包括:Performer(Choromanski et al., 2020)以及Linear Transformers(Katharopoulos et al., 2020);
- 最后是基于segment-based recurrence的方法,包括Transformer-XL以及Compressive Transformers(Rae et al., 2020).
3.2.1 Memory Compressed Transformer
Memory Compressed Transformer (Liu et al., 2018)是最为了处理长序列来修改Transformer的工作之一,这篇工作主要包括两方面的修改:
1. Local Attention Span:将输入的长序列分为相似长度的blocks,然后在每一个block上分别做self-attention,这使得每一个block的attention的cost一直,从而使得激活的数量尺度与输入长度呈线性关系;
2. Memory-compressed Attention:通过使用一个strided卷积来减少key和value的数量,query保持不变。这会减小attention矩阵的size以及attention的计算量,压缩的因子依赖于卷积的kernel size和strides。Memory-compressed Attention使得原来基于全局输入序列信息的方式改为局部的attention。
计算和内存复杂度: 设block size是
b
b
b,那么self-attention的计算复杂度就是
O
(
b
2
)
\mathcal{O}(b^2)
O(b2),共
n
/
b
n/b
n/b个blocks,local attention的内存和计算复杂度就是
O
(
b
⋅
n
)
\mathcal{O}(b\cdot n)
O(b⋅n);Memory-compressed Attention的计算复杂度是
O
(
n
2
/
k
)
\mathcal{O}(n^2/k)
O(n2/k),其中
k
k
k是卷积核的size和stride。
3.2.2 Image Transformer
这篇工作受到卷积神经网络的启发,将self-attention的感受野限制到local neighborhoods。这样可以使得模型可以处理更大的batch size并保持likelihood loss tractable。另外可以作为一种归纳偏置。Image Transformer提供了encoder-decoder结构,其中encoder生成每一个输入像素通道(pixel-channel)的上下文的表示,decoder在每一个time step对每一个像素自动回归生成一个channel。
Localized Attention Span 限制感受野到local neighborhood解决了对于大输入计算全局self-attention所带来的内存和计算量的消耗,但是改变了每一个query位置的neighborhood可能会阻止将self-attention打包为两个矩阵相乘(?)。为了避免这个问题,Image Transformer将输入分块为:“query blocks”及其相关联的“memory blocks”,其中所有的queries来自单个query block,模型处理相同的memory block。有两种选择query block及其相关联的memory block neighborhood的方法:1-d local attention 和2-d local attention
xxxx
3.2.3 Set Transformer
这篇工作是为了解决一个称为set-input的问题,输出不能受到输入数据的顺序和大小的影响(个人认为类似于聚类问题)。这篇工作用到了Transformer的Attention结构来获取输入set中不同元素的关系。这里只关注与efficient相关的问题,也就是为了解决与大小无关的问题。工作中提出了一种方法来讲原来二次复杂度的full self-attention改进为 O ( n m ) \mathcal{O}(nm) O(nm),其中 m m m是固定的超参。