Transformer step by step--Multi Head Attention


想要讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,我们要从以下几个方面入手:

①从图形入手,讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 在干什么;

②从公式推导,讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 怎么算;

③从代码讲解,讲清楚 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 应该怎么用;

那么我们先讲第一个部分:

Part 1. Multi Head Attention在干什么

想要理解 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,我们先理解 𝑠𝑒𝑙𝑓- 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,因为 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 就是 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 中的 𝑀𝑢𝑙𝑡𝑖 𝐻𝑒𝑎𝑑 等于1的情况。

为了说清楚 𝑠𝑒𝑙𝑓 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 大概在干什么,我们可以先回顾一下 𝑅𝑁𝑁 的结构。

李宏毅老师机器学习课程截图
我们可以通过 𝑅𝑁𝑁 的结构非常清楚地看到:

1.前面的信息需要不断地通过 𝑓𝑒𝑒𝑑 𝑓𝑜𝑟𝑤𝑎𝑑 才能传递到后面

2.信息流是单向的,后面时刻的信息包含了前面时刻的信息,但是前面的时刻的信息无法包含后面时刻

此时,我们根据上图归纳出了关于 𝑅𝑁𝑁 的两个特征,那么我们现在来看 𝑠𝑒𝑙𝑓 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的结构:

在这里插入图片描述
此时我们发现, 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 完美的解决了 𝑅𝑁𝑁 结构下会产生的问题。基于 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的计算方式,我们可以将所有时刻的向量同时输入到 𝑠𝑒𝑙𝑓 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 模块当中。也就是说,我们可以让不同时刻的输入互相“看到彼此”,从而让模型看到全局信息。

Part 2. Multi Head Attention怎么算

到这里我们已经说完了 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的作用了,那么接下来我们开始讲 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 的具体计算流程:

1.我们在将单词输入到模型之前,先通过 𝑤𝑜𝑟𝑑 𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 得到该句子中每个词的词向量,同时再通过 𝑃𝑜𝑠𝑖𝑡𝑖𝑜𝑛𝑎𝑙 𝐸𝑛𝑐𝑜𝑑𝑖𝑛𝑔 得到所有词的位置向量,将其相加。这时我们就得到了一个字的完整的向量表示,并将第 𝑡 个字记为 𝑥𝑡 。

2.下一步我们就会通过三个矩阵将我们的映射为 𝑞𝑡 , 𝑣𝑡 ,和 𝑘𝑡 ,在得到三个矩阵之后我们将原来的 𝑥𝑡 按照论文进行运算,最终得到我们的结果:
图片来源于Attention is all you need论文
那么这里的 q t q_{t} qt , v t v_{t} vt,和 k t k_{t} kt 是怎么得到的呢?

我们这里定义三个矩阵 W Q W_{Q} WQ, W V W_{V} WV W K W_{K} WK, 我们使用这三个矩阵分别对所有的字向量做三次线性变换,就得到了我们刚才提到的三个向量。

那我们为什么要对 映射三次呢?这里我们在参考了一些文章和博客之后,对三种映射的作用进行说明。

首先, Q Q Q, K K K, V V V物理意义上是一样的,都是由同一个句子中的不同token组成的矩阵。其中矩阵的行数为token个数,列数为Word Embedding维度。假设一个句子的“Hello,how are you?”长度是6,embedding维度是300,那么 Q Q Q, K K K, V V V的维度都是 6 * 300。

我们接下来先对 Q a Q_{a} Qa K b K_{b} Kb进行说明,这里我们用下标a和b分别表示token atoken b。我们从上面的公式可以知道,我们首先要做的一个运算就是 Q a Q_{a} Qa* K b K_{b} Kb,那么这个乘积运算,我们可以这样理解: Q a Q_{a} Qa K b K_{b} Kb 是将 token 投影到 d k d_{k} dk维空间的不同表示。因此,我们可以将这些投影的点积视为衡量 token 投影之间相似性的一个指标。对于通过 Q a Q_{a} Qa投影的每个向量,其与通过 K b K_{b} Kb投影的向量之间的点积衡量了这些向量之间的相似性。如果我们将 v i v_{i} vi u j u_{j} uj称为第 i i i个通过 Q a Q_{a} Qa投影得到的token的向量和第 j j j个通过 K b K_{b} Kb投影得到的token的向量,那么它们的点积可以被视为:
在这里插入图片描述
上述公式是对余弦相似度的变形

所以我们可以看到, Q a Q_{a} Qa K b K_{b} Kb之间的乘积直接反映了 Q a Q_{a} Qa K b K_{b} Kb 之间的相关性。如果此时我们的输入是“Hello,how are you?”那么我们通过QK之间的乘积,其实可以知道"Hello"这个单词与这句话中所有的文本的相关性,对所有的文本都重复这个过程,我们也就也得到相关性分数的矩阵。(这里直接展示Softmax之后的结果)
在这里插入图片描述
softmax之后的相关性矩阵

好,接下来我们再解释第三个向量: V V V向量。如上图所示,我们通过 Q Q Q V V V两个向量可以得到任意两个token之间的相关性矩阵,但此时的相似性矩阵很难再去体现原来的 x t x_{t} xt向量中的语义信息。而此时的 V V V向量仍然能够表示原来的句子,所以我们拿相似性矩阵去和 V V V相乘,就可以得到一个加权的结果。这一步相当于对 向量中的信息做一步提纯,因为 V V V中本来表示的就是不同token的word embedding向量,再乘上attention score之后,每个token都会对其它token做出基于权重的关注度调整,让每个单词关注到该关注的另一个单词。

Part 3. Multi Head Attention应该怎么用

图片来源于attention is all you need文章

这里我们一步一步写一个Multi Head Attention。同时为了方便大家理解,我们也不写成类的形式,而是一行行地为大家实现Multi Head Attention。

我们这里使用的输入样例来自于我们之前的文章(word and positional embedding)。在这篇文章当中,大家只需要知道我们的输入样例的形状如下所示:

input_x = x_batch_embedding + positional_encoding 
print(input_x.shape) 
## 
torch.Size([4, 4, 64]) 
#第一个维度的4代表四个批次,第二个维度的4代表Context_length,第三个维度的64代表embedding维度

接下来我们去定义三个线性层作为 , , 三个矩阵,并且将多头的数量定义为4。

num_heads = 8 #头的个数
Wq = nn.Linear(d_model, d_model)
Wk = nn.Linear(d_model, d_model) 
Wv = nn.Linear(d_model, d_model) 
Wo = nn.Linear(d_model, d_model) 
Q = Wq(input_x) 
K = Wk(input_x) 
V = Wv(input_x)
print(Q.shape) 
## 
torch.Size([4, 4, 64])
 #第一个维度的4代表四个批次,第二个维度的4代表Context_length,第三个维度的64代表embedding维度

接下来我们对QKV三个向量按照注意力头的数量去reshape,把embedding的维度按照多头的数量去平分。

Q = Q.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3) 
K = K.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3) 
V = V.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3) 
print(Q.shape) 
## 
torch.Size([4, 8, 4, 8]) #第一个维度代表batch,第二个维度代表头的数量,第三个维度是context_length,第四个维度是每个头上分配到的维度

接下来我们先完成 的点积得到attention matrix,接着对相关性矩阵做softmax,再将得到的结果和 向量相乘。在完成上述过程之后,我们将头拼接起来,并传入最后的全连接层,得到最终结果。

attn_score = Q @ K.transpose(-2, -1) / (d_model ** 0.5) 
print(attn_score.shape) 
attn_score = F.softmax(attn_score, dim = -1) 
attn_score = attn_score @ V attn_score = attn_score.permute(0, 2, 1, 3).reshape(batch, context_length, d_model) print(attn_score.shape) 
attn_score = Wo(attn_score) 
print(attn_score.shape) 
## 
torch.Size([4, 8, 4, 4]) #八个头,每个头都对应一个相关性矩阵 
torch.Size([4, 4, 64]) #将头拼接起来 torch.Size([4, 4, 64])

至此,我们基本完成了一个 s e l f self self- a t t e n t i o n attention attention的操作。

但是这里我们继续深入地探讨一个问题—— W Q W_{Q} WQ, W V W_{V} WV W K W_{K} WK这三个矩阵的维度必须要等于 d m o d e l d_{model} dmodel吗?

首先我们探讨 W Q W_{Q} WQ, W K W_{K} WK这两个矩阵,因为从之前的公式,我们可以发现相似度矩阵就是由这两个矩阵构成。

我们在表示 W Q W_{Q} WQ, W K W_{K} WK两个矩阵时,分别用了两个全连接层,并且全连接层的两个参数都为 d m o d e l d_{model} dmodel 。 关于第一个参数,我们可以确定一定为 d m o d e l d_{model} dmodel ,因为需要和输入向量的最后一个维度对齐。

但是第二个参数一定要为 d m o d e l d_{model} dmodel 吗?答案是否定的,我们这里假设第二个维度为 d r a n d o m d_{random} drandom。那我们的 Q Q Q K K K向量形状就变为 4 * 8 * 4 *$d_{random}$ // 8。此时我们在进行 Q Q Q * K T K^{T} KT 时,发现得到的结果依然会等于4 * 8 * 4 * 4所以从这里我们可以得出结论, Q Q Q K K K的维度不需要等于 d r a n d o m d_{random} drandom ,但是这两个向量的维度必须相等。

接下来,我们讨论 W V W_{V} WV这个矩阵,依旧根据上方的公式,我们发现 W V W_{V} WV矩阵是要和相似度矩阵相乘的,所以全连接层的一个维度也必须是 d r a n d o m d_{random} drandom ,第二个维度我们也假设为一个随机的维度 d r a n d o m 2 d_{random2} drandom2 。此时attention@V得到的结果形状为 4 * 8 * 4 *$d_{random2}$ // 8 。那么这样的形状会影响接下来的运算吗?其实并不会,因为我们最终还会通过一个 W O W_{O} WO的矩阵,此时我们只需要把 W O W_{O} WO的矩阵的第一个维度和 d r a n d o m 2 d_{random2} drandom2 // 8保持相同就可以。

所以总结一下: Q Q Q K K K的维度不需要等于 d m o d e l d_{model} dmodel ,但是一定要相同。 V V V 的维度也不一定需要等于 d m o d e l d_{model} dmodel ,但是要保证最后的全连接层和 V V V的维度相同。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会震pop的码农

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值