想要讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,我们要从以下几个方面入手:
①从图形入手,讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 在干什么;
②从公式推导,讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 怎么算;
③从代码讲解,讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 应该怎么用;
那么我们先讲第一个部分:
Part 1. Multi Head Attention在干什么
想要理解 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,我们先理解 𝑠𝑒𝑙𝑓- 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,因为 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 就是 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 中的 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 等于1的情况。
为了说清楚 𝑠𝑒𝑙𝑓 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 大概在干什么,我们可以先回顾一下 𝑅𝑁𝑁 的结构。
我们可以通过 𝑅𝑁𝑁 的结构非常清楚地看到:
1.前面的信息需要不断地通过 𝑓𝑒𝑒𝑑 𝑓𝑜𝑟𝑤𝑎𝑑 才能传递到后面
2.信息流是单向的,后面时刻的信息包含了前面时刻的信息,但是前面的时刻的信息无法包含后面时刻
此时,我们根据上图归纳出了关于 𝑅𝑁𝑁 的两个特征,那么我们现在来看 𝑠𝑒𝑙𝑓 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的结构:
此时我们发现, 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 完美的解决了 𝑅𝑁𝑁 结构下会产生的问题。基于 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的计算方式,我们可以将所有时刻的向量同时输入到 𝑠𝑒𝑙𝑓 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 模块当中。也就是说,我们可以让不同时刻的输入互相“看到彼此”,从而让模型看到全局信息。
Part 2. Multi Head Attention怎么算
到这里我们已经说完了 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的作用了,那么接下来我们开始讲 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的具体计算流程:
1.我们在将单词输入到模型之前,先通过 𝑤𝑜𝑟𝑑 𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 得到该句子中每个词的词向量,同时再通过 𝑃𝑜𝑠𝑖𝑡𝑖𝑜𝑛𝑎𝑙 𝐸𝑛𝑐𝑜𝑑𝑖𝑛𝑔 得到所有词的位置向量,将其相加。这时我们就得到了一个字的完整的向量表示,并将第 𝑡 个字记为 𝑥𝑡 。
2.下一步我们就会通过三个矩阵将我们的映射为 𝑞𝑡 , 𝑣𝑡 ,和 𝑘𝑡 ,在得到三个矩阵之后我们将原来的 𝑥𝑡 按照论文进行运算,最终得到我们的结果:
那么这里的
q
t
q_{t}
qt ,
v
t
v_{t}
vt,和
k
t
k_{t}
kt 是怎么得到的呢?
我们这里定义三个矩阵 W Q W_{Q} WQ, W V W_{V} WV 和 W K W_{K} WK, 我们使用这三个矩阵分别对所有的字向量做三次线性变换,就得到了我们刚才提到的三个向量。
那我们为什么要对 映射三次呢?这里我们在参考了一些文章和博客之后,对三种映射的作用进行说明。
首先,
Q
Q
Q,
K
K
K,
V
V
V物理意义上是一样的,都是由同一个句子中的不同token组成的矩阵。其中矩阵的行数为token个数,列数为Word Embedding
维度。假设一个句子的“Hello,how are you?”长度是6,embedding维度是300,那么
Q
Q
Q,
K
K
K,
V
V
V的维度都是 6 * 300。
我们接下来先对
Q
a
Q_{a}
Qa和
K
b
K_{b}
Kb进行说明,这里我们用下标a和b分别表示token a
和token b
。我们从上面的公式可以知道,我们首先要做的一个运算就是
Q
a
Q_{a}
Qa*
K
b
K_{b}
Kb,那么这个乘积运算,我们可以这样理解:
Q
a
Q_{a}
Qa 和
K
b
K_{b}
Kb 是将 token 投影到
d
k
d_{k}
dk维空间的不同表示。因此,我们可以将这些投影的点积视为衡量 token 投影之间相似性的一个指标。对于通过
Q
a
Q_{a}
Qa投影的每个向量,其与通过
K
b
K_{b}
Kb投影的向量之间的点积衡量了这些向量之间的相似性。如果我们将
v
i
v_{i}
vi和
u
j
u_{j}
uj称为第
i
i
i个通过
Q
a
Q_{a}
Qa投影得到的token的向量和第
j
j
j个通过
K
b
K_{b}
Kb投影得到的token的向量,那么它们的点积可以被视为:
上述公式是对余弦相似度的变形
所以我们可以看到,
Q
a
Q_{a}
Qa和
K
b
K_{b}
Kb之间的乘积直接反映了
Q
a
Q_{a}
Qa和
K
b
K_{b}
Kb 之间的相关性。如果此时我们的输入是“Hello,how are you?”那么我们通过QK之间的乘积,其实可以知道"Hello"这个单词与这句话中所有的文本的相关性,对所有的文本都重复这个过程,我们也就也得到相关性分数的矩阵。(这里直接展示Softmax之后的结果)
softmax之后的相关性矩阵
好,接下来我们再解释第三个向量:
V
V
V向量。如上图所示,我们通过
Q
Q
Q和
V
V
V两个向量可以得到任意两个token之间的相关性矩阵,但此时的相似性矩阵很难再去体现原来的
x
t
x_{t}
xt向量中的语义信息。而此时的
V
V
V向量仍然能够表示原来的句子,所以我们拿相似性矩阵去和
V
V
V相乘,就可以得到一个加权的结果。这一步相当于对 向量中的信息做一步提纯,因为
V
V
V中本来表示的就是不同token的word embedding
向量,再乘上attention score
之后,每个token都会对其它token做出基于权重的关注度调整,让每个单词关注到该关注的另一个单词。
Part 3. Multi Head Attention应该怎么用
图片来源于attention is all you need文章
这里我们一步一步写一个Multi Head Attention。同时为了方便大家理解,我们也不写成类的形式,而是一行行地为大家实现Multi Head Attention。
我们这里使用的输入样例来自于我们之前的文章(word and positional embedding)。在这篇文章当中,大家只需要知道我们的输入样例的形状如下所示:
input_x = x_batch_embedding + positional_encoding
print(input_x.shape)
##
torch.Size([4, 4, 64])
#第一个维度的4代表四个批次,第二个维度的4代表Context_length,第三个维度的64代表embedding维度
接下来我们去定义三个线性层作为 , , 三个矩阵,并且将多头的数量定义为4。
num_heads = 8 #头的个数
Wq = nn.Linear(d_model, d_model)
Wk = nn.Linear(d_model, d_model)
Wv = nn.Linear(d_model, d_model)
Wo = nn.Linear(d_model, d_model)
Q = Wq(input_x)
K = Wk(input_x)
V = Wv(input_x)
print(Q.shape)
##
torch.Size([4, 4, 64])
#第一个维度的4代表四个批次,第二个维度的4代表Context_length,第三个维度的64代表embedding维度
接下来我们对QKV三个向量按照注意力头的数量去reshape,把embedding的维度按照多头的数量去平分。
Q = Q.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3)
K = K.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3)
V = V.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3)
print(Q.shape)
##
torch.Size([4, 8, 4, 8]) #第一个维度代表batch,第二个维度代表头的数量,第三个维度是context_length,第四个维度是每个头上分配到的维度
接下来我们先完成 的点积得到attention matrix,接着对相关性矩阵做softmax,再将得到的结果和 向量相乘。在完成上述过程之后,我们将头拼接起来,并传入最后的全连接层,得到最终结果。
attn_score = Q @ K.transpose(-2, -1) / (d_model ** 0.5)
print(attn_score.shape)
attn_score = F.softmax(attn_score, dim = -1)
attn_score = attn_score @ V attn_score = attn_score.permute(0, 2, 1, 3).reshape(batch, context_length, d_model) print(attn_score.shape)
attn_score = Wo(attn_score)
print(attn_score.shape)
##
torch.Size([4, 8, 4, 4]) #八个头,每个头都对应一个相关性矩阵
torch.Size([4, 4, 64]) #将头拼接起来 torch.Size([4, 4, 64])
至此,我们基本完成了一个 s e l f self self- a t t e n t i o n attention attention的操作。
但是这里我们继续深入地探讨一个问题—— W Q W_{Q} WQ, W V W_{V} WV 和 W K W_{K} WK这三个矩阵的维度必须要等于 d m o d e l d_{model} dmodel吗?
首先我们探讨 W Q W_{Q} WQ, W K W_{K} WK这两个矩阵,因为从之前的公式,我们可以发现相似度矩阵就是由这两个矩阵构成。
我们在表示 W Q W_{Q} WQ, W K W_{K} WK两个矩阵时,分别用了两个全连接层,并且全连接层的两个参数都为 d m o d e l d_{model} dmodel 。 关于第一个参数,我们可以确定一定为 d m o d e l d_{model} dmodel ,因为需要和输入向量的最后一个维度对齐。
但是第二个参数一定要为
d
m
o
d
e
l
d_{model}
dmodel 吗?答案是否定的,我们这里假设第二个维度为
d
r
a
n
d
o
m
d_{random}
drandom。那我们的
Q
Q
Q和
K
K
K向量形状就变为 4 * 8 * 4 *$d_{random}$ // 8
。此时我们在进行
Q
Q
Q *
K
T
K^{T}
KT 时,发现得到的结果依然会等于4 * 8 * 4 * 4
。所以从这里我们可以得出结论,
Q
Q
Q和
K
K
K的维度不需要等于
d
r
a
n
d
o
m
d_{random}
drandom ,但是这两个向量的维度必须相等。
接下来,我们讨论
W
V
W_{V}
WV这个矩阵,依旧根据上方的公式,我们发现
W
V
W_{V}
WV矩阵是要和相似度矩阵相乘的,所以全连接层的一个维度也必须是
d
r
a
n
d
o
m
d_{random}
drandom ,第二个维度我们也假设为一个随机的维度
d
r
a
n
d
o
m
2
d_{random2}
drandom2 。此时attention@V得到的结果形状为 4 * 8 * 4 *$d_{random2}$ // 8
。那么这样的形状会影响接下来的运算吗?其实并不会,因为我们最终还会通过一个
W
O
W_{O}
WO的矩阵,此时我们只需要把
W
O
W_{O}
WO的矩阵的第一个维度和
d
r
a
n
d
o
m
2
d_{random2}
drandom2 // 8保持相同就可以。
所以总结一下: Q Q Q和 K K K的维度不需要等于 d m o d e l d_{model} dmodel ,但是一定要相同。 V V V 的维度也不一定需要等于 d m o d e l d_{model} dmodel ,但是要保证最后的全连接层和 V V V的维度相同。