MultiHeadAttention 是 Transformer 模型中的一个核心组件,它允许模型在处理序列的每个位置时同时考虑来自多个“视角”(即头部)的信息。这样做可以提高模型对不同位置关系的理解能力。
重点讲解
主要步骤:
- 线性变换得到QKV,并将QKV分割为多头
- 计算缩放点积注意力(注意mask可选)
- 拼接多头
- 最后再进行一次线性变换
代码实现
下面,我将使用 PyTorch 框架实现一个基本的 MultiHeadAttention
模块。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.depth = d_model // num_heads
# 定义线性层和输出线性层
self.query_linear = nn.Linear(d_model, d_model)
self.key_linear = nn.Linear(d_model, d_model)
self.value_linear = nn.Linear(d_model, d_model)
self.final_linear = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
"""分割最后一个维度到 (num_heads, depth).
转置结果使得形状为 (batch_size, num_heads, seq_length, depth)
"""
x = x.view(batch_size, -1, self.num_heads, self.depth)
return x.permute(0, 2, 1, 3)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1. 线性层和分割到多头
query = self.split_heads(self.query_linear(query), batch_size)
key = self.split_heads(self.key_linear(key), batch_size)
value = self.split_heads(self.value_linear(value), batch_size)
# 2. 缩放点积注意力
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)
if mask is not None:
scores = scores.masked_fill(mask == True, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
# 3. 将注意力权重应用到值上
output = torch.matmul(attention_weights, value)
# 4. 连接头部
output = output.permute(0, 2, 1, 3).contiguous()
output = output.view(batch_size, -1, self.d_model)
# 5. 最后一次线性变换
output = self.final_linear(output)
return output
流程图(维度变换示意图)
self-attention示例
d_model = 512 # 模型维度
num_heads = 8 # 头数
mha = MultiHeadAttention(d_model, num_heads)
# 创建随机数据
batch_size = 4
seq_length = 60
x = torch.rand(batch_size, seq_length, d_model) # 输入假设维度为 (batch_size, seq_length, d_model)
output = mha(x, x, x) # 自注意力机制,qkv的输入相同;而cross-attention中,query来自decoder,kv来自encoder
print(output.shape)
加入mask示例
解码器的自注意力层需要确保当前位置只能注意到前面的位置(包括当前位置),而不是未来的位置。这通常通过一个未来位置掩码实现,它是一个下三角矩阵。
import torch
def generate_square_subsequent_mask(seq_len):
"""生成一个未来步骤掩码,用于解码器中防止看到未来信息。"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() # diagonal 控制对角线开始的位置
return mask
d_model = 512 # 模型维度
num_heads = 8 # 头数
mha = MultiHeadAttention(d_model, num_heads)
# 创建随机数据
batch_size = 4
seq_length = 60
x = torch.rand(batch_size, seq_length, d_model) # 输入假设维度为 (batch_size, seq_length, d_model)
# 生成掩码并将其应用于解码器的自注意力层
future_mask = generate_square_subsequent_mask(seq_length).to(x.device)
output = mha(x, x, x, mask=future_mask) # 自注意力机制
print(output.shape) # 应为 (batch_size, seq_length, d_model)
注意广播机制
在 PyTorch 中,masked_fill
函数可以很灵活地处理维度差异情况,通过广播(broadcasting)机制来匹配维度。