手撕多头注意力机制

使用的是因果mask

import torch.nn as nn
import torch


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, seq_len, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out 必须能够被 num_heads 整除"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # 确保每个头的输出维度是 d_out 的 num_heads 分之一

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)  # 线性层,用于生成查询矩阵
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)  # 线性层,用于生成键矩阵
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)  # 线性层,用于生成值矩阵
        self.out_proj = nn.Linear(d_out, d_out)  # 线性层,用于将多头注意力输出结合在一起
        self.dropout = nn.Dropout(dropout)  # dropout层,防止过拟合

        # 使用上三角矩阵创建mask,确保模型只关注到当前和之前的token,避免看到未来的信息
        self.register_buffer('mask', torch.triu(torch.ones(seq_len, seq_len), diagonal=1))

    def forward(self, x):
        # x的形状:(batch_size, seq_len, d_in)
        b, seq_len, d_in = x.shape  # b: batch_size, seq_len: 序列长度, d_in: 输入维度

        # 通过线性层生成查询、键和值矩阵,形状均为 (batch_size, seq_len, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # 将每个矩阵划分为多个头,增加一个 num_heads 维度,形状为 (batch_size, seq_len, num_heads, head_dim)
        keys = keys.view(b, seq_len, self.num_heads, self.head_dim)
        values = values.view(b, seq_len, self.num_heads, self.head_dim)
        queries = queries.view(b, seq_len, self.num_heads, self.head_dim)

        # 转置矩阵,使 num_heads 维度提前,形状变为 (batch_size, num_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # 计算缩放点积注意力,queries 和 keys 矩阵相乘,得到注意力得分,形状为 (batch_size, num_heads, seq_len, seq_len)
        attn_scores = queries @ keys.transpose(-2, -1)  # 键矩阵转置用于点积操作

        # 将原始 mask 矩阵裁剪至实际的序列长度,并转换为布尔矩阵
        mask_bool = self.mask.bool()[:seq_len, :seq_len]

        # 使用 mask 将未来的 token 屏蔽掉,确保模型无法看到未来的信息
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # 计算注意力权重,使用 softmax 函数归一化,形状为 (batch_size, num_heads, seq_len, seq_len)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)  # 应用 dropout 以减少过拟合

        # 将注意力权重与值矩阵相乘,得到上下文向量,形状为 (batch_size, num_heads, seq_len, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # 将多头的上下文向量合并,形状变为 (batch_size, seq_len, d_out)
        context_vec = context_vec.reshape(b, seq_len, self.d_out)

        # 可选:通过输出投影层进一步处理上下文向量
        context_vec = self.out_proj(context_vec)

        return context_vec


# 设置随机种子以确保结果可复现
torch.manual_seed(123)

# 模拟输入数据 (batch_size, seq_len, d_in),这里我们创建一个随机张量
batch_size = 3
seq_len = 5  # 实际序列长度
d_in = 10  # 输入维度
d_out = 8  # 输出维度
num_heads = 2  # 多头数量
dropout = 0.1  # dropout 概率

# 创建 MultiHeadAttention 实例
multi_head_attention = MultiHeadAttention(d_in, d_out, seq_len, dropout, num_heads)

# 输入的随机张量
x = torch.randn(batch_size, seq_len, d_in)

# 调用 forward 函数,输出上下文向量
output = multi_head_attention(x)

# 输出上下文向量
print("上下文向量 (context vectors):")
print(output)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值