我们详细分析了 𝑀𝑢𝑙𝑡𝑖 𝑆𝑒𝑙𝑓 - 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,但是其实看过论文的同学都会有一个问题——Transformer中的Decoder部分不是还有个掩码的操作吗,那是个什么玩意儿?
图片来源于Attention is all you need论文
今天这篇文章就是告诉大家:
-
Decoder部分的掩码是个什么玩意儿
-
怎么用简单的代码去实现这个掩码
我们先来讲第一个问题:
Decoder部分的掩码是个什么玩意儿,为什么要有这个掩码?
首先,我们先要有一个概念。这个概念就是传统的Seq2Seq中Decoder使用的是基于RNN的模型,因此我们给模型输入的始终都是当前时刻的词 𝑥𝑡 (这里我们用 𝑡 来表示当前时刻),模型不论怎么样都不会看到未来的输入,即 𝑥𝑡+1 。但是由于Transformer的Decoder使用的是 𝑆𝑒𝑙𝑓 - 𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 ,导致在Decoder层中所有的输入信息都暴露出来,这显示是不对的。此时,我们就需要对Decoder做一些处理,而这样的处理就被称之为Mask。
举个例子,Decoder 的 ground truth 为 "<start> I am fine",我们将这个句子输入到 Decoder 中,经过 WordEmbedding 和 Positional Encoding 之后,将得到的矩阵做三次线性变换。然后进行 self-attention 操作,首先通过得到相关性矩阵,接下来非常关键,我们要对相关性矩阵进行 Mask,举个例子,当我们输入 "I" 时,模型目前仅知道包括 "I" 在内之前所有字的信息,即 "<start>" 和 "I" 的信息,不应该让其知道 "I" 之后词的信息。道理很简单,我们做预测的时候是按照顺序一个字一个字的预测,怎么能这个字都没预测完,就已经知道后面字的信息了呢?Mask 非常简单,首先生成一个下三角全 0,上三角全为负无穷的矩阵,然后将其与 Scaled Scores 相加即可
之后再做 softmax,就能将 - inf 变为 0,得到的这个矩阵即为每个字之间的权重:
好,概念已经梳理完毕,那我们接下来就说第二个问题:
怎么用简单的代码去实现这个掩码?
首先, 通过上面的概念可知,我们需要用一个context_length * context_length 形状的矩阵作为mask矩阵。这里我们使用之前论文中出现的例子(之前的例子可访问这篇文章):
input_x = x_batch_embedding + positional_encoding
print(input_x.shape)
num_heads = 4
Wq = nn.Linear(d_model, d_model)
Wk = nn.Linear(d_model, d_model)
Wv = nn.Linear(d_model, d_model)
Wo = nn.Linear(d_model, d_model)
Q = Wq(input_x)
K = Wk(input_x)
V = Wv(input_x)
Q = Q.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3)
K = K.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3)
V = V.reshape(batch, context_length, num_heads, d_model // num_heads).permute(0, 2, 1, 3)
print(Q.shape)
attn_score = Q @ K.transpose(-2, -1) / (d_model ** 0.5)
attn_score.shape
##
torch.Size([4, 4, 64])
torch.Size([4, 4, 4, 16])
torch.Size([4, 4, 4, 4])
通过代码,我们发现此时的相关性矩阵的形状为4 * 4 * 4 * 4。其中第一个维度是我们初始设置的批次大小,第二个维度是头的数量,第三个维度和第四个维度都代表着context_length。那么此时为了对这样一个 4 * 4的相似性矩阵做mask,我们也需要构造一个4 * 4的矩阵。
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
attn_score = attn_score.masked_fill(mask, -float("inf"))
attn_score[0][0]
##
tensor([[ 0.4195, -inf, -inf, -inf], [ 0.1423, -0.4002, -inf, -inf], [ 0.0469, -0.1568, -0.2440, -inf], [ 0.2692, -0.6866, -0.4085, 0.0288]], grad_fn=<SelectBackward0>)
这里我们用torch.triu函数来帮助构造相关的mask矩阵。观察结果,成功mask了 𝑥𝑡+1 时刻的信息。
参考文章: