【大规模语言模型:从理论到实践】Transformer中MultiHeadAttention详解

class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads  # 每个头的维度,d_model 除以 heads
        self.h = heads  # 多头的数量
        
        # 定义线性变换层,用于将输入分别映射到 q, k, v 空间
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        
        # 定义 dropout 层,用于防止过拟合
        self.dropout = nn.Dropout(dropout)
        
        # 输出线性层,用于将多头注意力拼接后的结果映射回 d_model 维度
        self.out = nn.Linear(d_model, d_model)

    # 注意力机制函数,用于计算 q, k, v 的注意力分数,并应用 softmax 和 dropout
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # 计算点积注意力分数,q 和 k 点积并除以 d_k 的平方根以防止梯度消失
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        
        # 如果有掩码,则将填充位置的分数设为 -1e9,使得它们在 softmax 中接近 0
        if mask is not None:
            mask = mask.unsqueeze(1)  # 扩展维度,使其与 scores 形状匹配
            scores = scores.masked_fill(mask == 0, -1e9)  # 掩盖填充位置
        
        # 对 scores 应用 softmax,使其成为概率分布
        scores = F.softmax(scores, dim=-1)
        
        # 如果有 dropout,应用 dropout 以增强正则化
        if dropout is not None:
            scores = dropout(scores)
        
        # 计算注意力输出,将注意力权重和 value 相乘
        output = torch.matmul(scores, v)
        return output

    # 前向传播函数,计算多头注意力的输出
    def forward(self, q, k, v, mask=None):
        bs = q.size(0)  # 获取 batch size
        
        # 对 query, key, value 进行线性变换,并 reshape 为 (batch_size, seq_len, heads, d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # 将第二维(序列长度)和第三维(heads)进行转置,便于后续计算注意力
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # 计算注意力分数,使用上面定义的 attention 函数
        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)

        # 将多个 heads 的输出结果拼接起来,形状恢复为 (batch_size, seq_len, d_model)
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)

        # 最后通过线性层得到最终的输出
        output = self.out(concat)

        return output
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值