Masked Multi-Head Attention(掩码多头注意力)是Transformer模型中的一个关键机制,广泛应用于各种自然语言处理任务,如机器翻译、文本生成等。它通过引入掩码和多头机制,提升了模型在处理序列数据时的灵活性和准确性。
Multi-Head Attention机制
在介绍Masked Multi-Head Attention之前,先了解一下基本的Multi-Head Attention机制。Multi-Head Attention机制通过并行的多个注意力头(attention heads)来捕捉不同的上下文信息,提高模型的表现能力。具体来说:
-
输入嵌入(Input Embeddings):假设输入序列为 {𝑥1,𝑥2,…,𝑥𝑛}{x1,x2,…,xn}。
-
线性变换:将输入嵌入通过线性变换得到查询(Query)、键(Key)和值(Value)矩阵:
𝑄=𝑋𝑊𝑄,𝐾=𝑋𝑊𝐾,𝑉=𝑋𝑊𝑉Q=XWQ,K=XWK,V=XWV
其中 𝑊𝑄WQ、𝑊𝐾WK 和 𝑊𝑉WV 是可学习的权重矩阵。
-
计算注意力权重:对于每个头,计算查询和键的点积,除以缩放因子 𝑑𝑘dk 并通过Softmax函数得到注意力权重:
Attention(𝑄,𝐾,𝑉)=Softmax(𝑄𝐾𝑇𝑑𝑘)𝑉Attention(Q,K,V)=Softmax(dkQKT)V
-
多头注意力:将多个头的输出拼接起来,并通过线性变换得到最终的输出:
MultiHead(𝑄,𝐾,𝑉)=Concat(head1,head2,…,headℎ)𝑊𝑂MultiHead(Q,K,V)=Concat(head1,head2,…,headh)WO
其中每个头的计算方式相同,但具有不同的参数。
Masked Multi-Head Attention
Masked Multi-Head Attention特别用于Transformer模型的解码器部分,尤其在自回归生成任务中,例如文本生成。它在计算注意力权重时引入了掩码,以确保模型在生成下一个单词时,只能访问到当前单词及其之前的单词,而不能看到未来的信息。
掩码(Masking)
掩码的主要作用是屏蔽未来的时间步,防止模型在预测某个时间步的输出时访问到未来的单词,从而保持自回归的性质。这是通过在计算注意力权重时引入一个掩码矩阵实现的。
-
掩码矩阵:创建一个上三角掩码矩阵 𝑀M,其形式如下:
𝑀𝑖𝑗={0if 𝑗≤𝑖−∞if 𝑗>𝑖Mij={0−∞if j≤iif j>i
该矩阵确保在计算注意力时,位置 𝑖i 只能看到位置 𝑖i 及其之前的位置。
-
应用掩码:在计算注意力权重之前,将点积结果加上掩码矩阵 𝑀M:
Attention(𝑄,𝐾,𝑉)=Softmax(𝑄𝐾𝑇+𝑀𝑑𝑘)𝑉Attention(Q,K,V)=Softmax(dkQKT+M)V
这样,通过将不该访问的未来位置的得分变为负无穷,在Softmax计算中它们的权重将变为零。
具体步骤
- 线性变换:将输入序列通过线性变换得到多个头的查询、键和值矩阵。
- 计算掩码注意力:对每个头应用掩码后计算注意力:
MaskedAttention(𝑄,𝐾,𝑉)=Softmax(𝑄𝐾𝑇+𝑀𝑑𝑘)𝑉MaskedAttention(Q,K,V)=Softmax(dkQKT+M)V
- 多头拼接和线性变换:将多个头的输出拼接后,通过线性变换得到最终输出。
应用场景
Masked Multi-Head Attention广泛用于:
- 文本生成:如机器翻译、文本摘要、对话系统等,自回归生成下一个单词时需要屏蔽未来信息。
- 序列预测:任何需要保证序列预测严格按时间步进行的任务。
总结
Masked Multi-Head Attention通过引入掩码机制,确保在自回归任务中模型只能访问历史信息,从而保证预测的顺序性。结合多头注意力,它能够有效捕捉输入序列中的不同层次和不同位置的上下文信息,提高模型的灵活性和准确性。这一机制是Transformer及其变体成功的关键因素之一。