This implementation uses a causal mask: each position can attend only to itself and earlier tokens, never to future ones.
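As a quick standalone illustration (added here, not part of the class below), torch.triu with diagonal=1 places 1s strictly above the main diagonal; exactly those "future" positions are the ones later filled with -inf:

import torch
# 1s above the main diagonal mark the future positions to be masked out
mask = torch.triu(torch.ones(4, 4), diagonal=1)
print(mask)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])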
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, seq_len, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # each head gets d_out / num_heads dimensions
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)  # linear layer producing the query matrix
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)    # linear layer producing the key matrix
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)  # linear layer producing the value matrix
        self.out_proj = nn.Linear(d_out, d_out)  # linear layer that combines the per-head outputs
        self.dropout = nn.Dropout(dropout)  # dropout layer to reduce overfitting
        # Build the mask from an upper-triangular matrix so the model attends only to
        # the current and earlier tokens, never to future information
        self.register_buffer('mask', torch.triu(torch.ones(seq_len, seq_len), diagonal=1))

    def forward(self, x):
        # x has shape (batch_size, seq_len, d_in)
        b, seq_len, d_in = x.shape  # b: batch size, seq_len: sequence length, d_in: input dimension
        # Project the input into queries, keys, and values, each of shape (batch_size, seq_len, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Split each matrix into heads by adding a num_heads dimension:
        # (batch_size, seq_len, num_heads, head_dim)
        keys = keys.view(b, seq_len, self.num_heads, self.head_dim)
        values = values.view(b, seq_len, self.num_heads, self.head_dim)
        queries = queries.view(b, seq_len, self.num_heads, self.head_dim)
        # Transpose so num_heads comes first: (batch_size, num_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        # Scaled dot-product attention: multiply queries by keys to get attention scores
        # of shape (batch_size, num_heads, seq_len, seq_len)
        attn_scores = queries @ keys.transpose(-2, -1)  # keys are transposed for the dot product
        # Truncate the stored mask to the actual sequence length and convert it to booleans
        mask_bool = self.mask.bool()[:seq_len, :seq_len]
        # Mask out future tokens so the model cannot see ahead
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        # Normalize with softmax to get attention weights of shape (batch_size, num_heads, seq_len, seq_len)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)  # apply dropout to reduce overfitting
        # Multiply the attention weights by the values to get context vectors; after the
        # transpose the shape is (batch_size, seq_len, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # Merge the heads back together: (batch_size, seq_len, d_out)
        context_vec = context_vec.reshape(b, seq_len, self.d_out)
        # Optional: pass the context vectors through the output projection
        context_vec = self.out_proj(context_vec)
        return context_vec
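As a sanity check of the attention math above, the manual mask-then-softmax path can be compared against PyTorch's built-in fused kernel. This is a minimal sketch added here (not part of the original listing), assuming PyTorch 2.0+ for torch.nn.functional.scaled_dot_product_attention; the tensor sizes are arbitrary:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 5, 4)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)

# Manual path: scores -> causal mask -> softmax -> weighted sum of values
scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
manual = torch.softmax(scores.masked_fill(causal, -torch.inf), dim=-1) @ v

# Built-in fused path with the same causal masking
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-6))  # True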
# Set a random seed so the results are reproducible
torch.manual_seed(123)

# Simulated input of shape (batch_size, seq_len, d_in), built as a random tensor
batch_size = 3
seq_len = 5    # actual sequence length
d_in = 10      # input dimension
d_out = 8      # output dimension
num_heads = 2  # number of attention heads
dropout = 0.1  # dropout probability

# Create a MultiHeadAttention instance
multi_head_attention = MultiHeadAttention(d_in, d_out, seq_len, dropout, num_heads)

# Random input tensor
x = torch.randn(batch_size, seq_len, d_in)

# Call forward to get the context vectors
output = multi_head_attention(x)

# Print the context vectors
print("Context vectors:")
print(output)
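The result should have shape (batch_size, seq_len, d_out), i.e. (3, 5, 8) here; a one-line check added for illustration:

print(output.shape)  # torch.Size([3, 5, 8])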