模型介绍
现如今NLP领域的预训练模型实在是太大了,从最开始的显存装不下,到内存装不下,再到如今的硬盘装不下,让多少人望而却步,大模型就非得这么耗存储吗?有没有优化手段呢?针对长序列的Transformer训练问题,Reformer给出了一种存储极致压缩的方案。
Reformer主要涉及了四处创新点:
- 使用Axial Positional Embedding来减小位置编码矩阵
- 提出了基于局部敏感性哈希算法(Locality Sensitive Hashing, LSH)的self-attention机制
- 提出可以使用分段的方式来处理全连接层(FFN)的计算过程,每次只计算一部分,从而不需要将整个序列加载到内存
- 使用可逆(Reversible)残差连接来代替传统残差连接,好处是在前向过程中前 层的中间值和输出都不需要存储了,只保留最后一层的输出
模型改进
Axial Positional Embedding
Axial Positional Embedding(APE)可用于输入是多维数组(Tensor)的场景,不过好在NLP的输入基本都是一维的序列,理解起来更容易些,我们假设位置向量维度是
d
d
d,输入序列的长度是
L
L
L。并且
L
=
m
∗
n
L=m*n
L=m∗n,我们可以将序列排列成矩阵
P
P
P,如下图所示,则矩阵元素
P
i
,
j
P_{i,j}
Pi,j对应原始序列中第
(
i
−
1
)
∗
m
+
j
(i-1)*m+j
(i−1)∗m+j个位置(注:
i
i
i和
j
j
j 的下标都是从1开始)。
我们为矩阵
P
P
P的每一行都创建一个向量
r
i
r_i
ri,同时也为每一列都创建一个向量
c
i
c_i
ci,向量维度分别是
d
1
d_1
d1、
d
2
d_2
d2,这样
P
i
j
P_{ij}
Pij就关联了两个向量
(
r
i
,
c
i
)
(r_i,c_i)
(ri,ci),如果我们再令
d
1
+
d
2
=
d
d_1+d_2=d
d1+d2=d,就可以让
c
o
n
c
a
t
(
r
i
,
c
i
)
concat(r_i,c_i)
concat(ri,ci)表示
P
i
j
P_{ij}
Pij的位置向量,维度是
d
d
d,这就是一维场景下的APE。
我们再来分析下参数量,传统位置向量矩阵大小是 L ∗ d L*d L∗d,使用APE方法会得到两个矩阵,大小分别是 m ∗ d 1 m*d_1 m∗d1、 m ∗ d 2 m*d_2 m∗d2,我们再看下enwik8-64K(序列长度是64K)的例子, L = 10000 , d = 512 L=10000,d=512 L=10000,d=512,如果 m = 253 , n = 253 , d 1 = 256 , d 2 = 256 m=253,n=253,d_1=256,d_2=256 m=253,n=253,d1=256,d2=256,参数量差不多从3千万减少到了13万!
分段FFN
全连接层一直是神经网络中的参数量大户,Transformer也不例外,
W
1
W_1
W1和
W
2
W_2
W2大小均是
d
m
o
d
e
l
∗
d
f
f
d_{model}*d_{ff}
dmodel∗dff,注意Transformer中的FFNN全称是Position-wise Feed-Forward Networks,重点就是这个position-wise,区别于普通的全连接网络,这里FFN的输入是序列中每个位置上的元素,而不是整个序列,所以每个元素完全可以独立计算,最极端节省内存的做法是遍历序列,每次只取一个元素得到FFN的结果,但是这样做时间消耗太大,“分段”的含义就是做下折中,将序列分成N段,也就是N个子序列,每次读取一个子序列进行FFN计算,最后将N份的结果拼接:
分段FFN只是一种计算上的技巧,计算结果和原始FFN完全一致,所以不会影响到模型效果,好处是不需要一次性将整个序列
(
b
a
t
c
h
_
s
i
z
e
,
L
,
d
m
o
d
e
l
)
(batch\_size,L,d_{model})
(batch_size,L,dmodel)读入内存,劣势当然是会增加额外的时间开销了。
可逆(Reversible)残差连接
可逆残差网络它最主要的特点是每一层的激活值都可以通过下一层的激活值计算得到,因此不需要保持大量的后向传播的参数,从而达到节约显存的目的。
可逆神经网络按照通道将神经网络输入 x x x 分成两部分,表示为 x 1 x_1 x1和 x 2 x_2 x2, F F F和 G G G结果相同,就是普通的残差块,可逆残差块的输出是 c o n c a t ( y 1 , y 2 ) concat(y_1,y_2) concat(y1,y2)。
在反向过程中,
x
1
x_1
x1和
x
2
x_2
x2 可以通过以下方法得到:
下图是可逆残差块的结构:
Transformer中也含有残差连接,自然也可以改成可逆残差块:
做法和可逆残差网络完全相同:
Y
1
=
X
1
+
A
t
t
e
n
t
i
o
n
(
X
2
)
,
Y
2
=
X
2
+
F
e
e
d
F
o
r
w
a
r
d
(
Y
1
)
Y_1=X_1+Attention(X_2),Y_2=X_2+FeedForward(Y_1)
Y1=X1+Attention(X2),Y2=X2+FeedForward(Y1)。
Reformer如何分割得到 X 1 X_1 X1和 X 2 X_2 X2,在第一层之前,将输入copy了一份,也就是 X = X 1 = X 2 X=X_1=X_2 X=X1=X2。其余层的 X 1 X_1 X1与 X 2 X_2 X2则对应 Y 1 Y_1 Y1 和 Y 2 Y_2 Y2。在前向计算过程中,每一个残差块i输出 Y i , 1 Y_{i,1} Yi,1和 Y i , 2 Y_{i,2} Yi,2,再作为下一层输入计算得到下一层的 Y i + 1 , 1 Y_{i+1,1} Yi+1,1和 Y j + 1 , 2 Y_{j+1,2} Yj+1,2就可以删除 Y i , 1 Y_{i,1} Yi,1和 Y i , 2 Y_{i,2} Yi,2。这样,即使Reformer有 N层,在前向过程中我们也只需保留最后一层的 Y n , 1 Y_{n,1} Yn,1和 Y n , 2 Y_{n,2} Yn,2即可。
注:可逆残差块的结构和原始Transformer中的结构已经不相同了,原始Transformer是 Y 1 = X 1 + A t t e n t i o n ( X 1 ) , Y 2 = Y 1 + F e e d F o r w o r d ( Y 1 ) Y_1=X_1+Attention(X_1),Y_2=Y_1+FeedForword(Y_1) Y1=X1+Attention(X1),Y2=Y1+FeedForword(Y1),所以在反向过程中无法还原出 X 1 X_1 X1和 X 2 X_2 X2,也就是不可逆。
局部敏感性哈希Self-Attention
由于softmax先指数缩放再归一化的本质,使其极易受到极大值影响,特别是长序列,得到的分布几乎总是稀疏的,这种稀疏性分布是有现实含义的:序列中的某个元素一般只会和少数几个元素具有较高的相似性/关联性。
一种通用的思路就是如果我们能为每个query q i q_i qi找到最相似的 K K K个key,只和它们计算点积,再取softmax,只要寻找相似key集合速度够快,并且 K << L,这种方式不论是时间还是空间都是有优势的。那么问题来了,如何快速找到和 q i q_i qi 最相似的key呢?
这个问题可以转化为向量检索问题,向量检索又可以分为精确(exact)检索和近似(approximate)检索两类,后者在学术界被称为**approximated nearest neighbor search(ANN)**问题,在工业界有非常广泛的应用,毕竟能提供检索服务的场景数据量一般都非常大,如何快速召回就显得格外重要,ANN问题的解决方案包括局部敏感性哈希(Locality Sensitive Hashing, LSH)、树方法和Product Quantization。
Reformer选择的就是LSH,啥是LSH呢?提起哈希函数,大家应该都不陌生,局部敏感性哈希函数是一类特殊的哈希函数,特殊在“局部敏感”上面,如果一个哈希函数具有如下的特点,则它就属于局部敏感性哈希函数:
如果有三个点 x 1 , x 2 , x 3 x_1,x_2,x_3 x1,x2,x3,如果 x 1 x_1 x1和 x 2 x_2 x2相邻, x 1 x_1 x1和 x 3 x_3 x3相隔较远,同时 h a s h ( x 1 ) hash(x_1) hash(x1)和 h a s h ( x 2 ) hash(x_2) hash(x2) 碰撞的概率要比 h a s h ( x 1 ) hash(x_1) hash(x1) 和 h a s h ( x 3 ) hash(x_3) hash(x3) 碰撞的概率大得多,我们就说 h a s h ( ) hash() hash()函数属于LSH,也就是它保留了原始数据之间的距离属性。
LSH有很多种,Reformer里用到的是基于随机投影(random projections)的哈希方法,实现非常简单,假设我们希望有 b 个哈希结果,也就是哈希函数对应 b 个桶,首先创建大小是 d m o d e l ∗ b 2 d_{model} * \frac b 2 dmodel∗2b 的矩阵 R R R,矩阵元素服从标准正态分布,向量 x x x哈希值等于 a r g m a x ( c o n c a t ( x R , − x R ) ) argmax(concat(xR,-xR)) argmax(concat(xR,−xR))。
随机投影是一种降维方法,最大的特点估计就是简单,构造一个随机矩阵 x ∗ W x*W x∗W, 完事了,这玩意真的有用?随机投影背后的数学原理是Johnson–Lindenstrauss lemma:
继续说回LSH self-attention,我们知道,在Transformer中,
Q
,
K
,
V
Q,K,V
Q,K,V 是通过
X
W
Q
,
X
W
K
,
X
W
V
XW_Q,XW_K,XW_V
XWQ,XWK,XWV 得到的,由于经过了不同的线性映射(乘以矩阵),即使是同一位置的
q
i
q_i
qi和
k
i
k_i
ki都很难保证哈希值相同,这还怎么找相似?Reformer的解决办法是让
Q
=
K
Q=K
Q=K,实验证明,这样做并不会影响模型效果,还有一个问题,当
Q
=
K
Q=K
Q=K时,
q
i
q_i
qi和
k
i
k_i
ki相同,
k
i
k_i
ki必然属于最相似的
K
K
K个key,并且
q
i
q_i
qi 和
k
i
(
q
i
)
k_i(q_i)
ki(qi) 的点积结果一定是最大的,同样会造成softmax稀疏,所以LSH self-attention会mask掉自己。
在得到每个 q i q_i qi 的哈希值后,如何快速得到每个 q i q_i qi 的最相似的key集合?也就是和 q i q_i qi 落入同一个桶内的 S e t q i = q j , j ≠ i Set_{q_i}=q_j,j \not =i Setqi=qj,j=i?Reformer选择了排序,也就是下图中的"Sort by LSH bucket",将落入同一个桶的 q q q 排列在一起,下面就可以进行点积计算了,这时候又有一个问题,有的桶元素个数多,有的桶元素个数少,如果以桶数据组成batch,明显各个batch的size不同,不便于批处理,那就分段(chunk)吧,假设序列长度 L L L,一共有 b b b个桶,则平均下来每个桶内有 L b \frac L b bL个元素,考虑到有的桶元素个数会大于 L b \frac L b bL,为了让同一个桶内的 q j q_j qj都属于 S e t q i Set_{q_i} Setqi ,肯定要让每一段序列长度大于 L b \frac L b bL,Reformer设置的子序列长度是 2 ∗ L b 2*\frac L b 2∗bL,平均一个子序列包含两个桶的数据,具体在计算点积时, q i q_i qi会在它所在的子序列和前一个子序列中找同一个桶内的 q j q_j qj 进行点积,然后计算softmax,得到 v i v_i vi 。
考虑到LSH毕竟有误差,有可能很相似的
q
i
q_i
qi和
q
j
q_j
qj 没有落入同一个桶内,那就多来几轮哈希,将每一轮哈希得到的
S
e
t
q
i
Set_{q_i}
Setqi 取并集作为最终的key集合。
模型参考
论文地址:https://arxiv.org/abs/2001.04451
代码地址:https://github.com/google/trax/tree/master/trax/models/reformer