多头注意力模块 (Multi-Head Attention, MHA) 代码实现(pytorch)

多头注意力是Transformer架构中的关键组件,它允许模型在处理序列数据时,同时关注序列中的多个位置,以下为该模块的简单实现:

导入依赖

import torch
import torch.nn as nn
import torch.nn.functional as F

导入所需的PyTorch库。

定义多头注意力类

class MultiHeadAttention(nn.Module):

定义一个继承自nn.Module的类MultiHeadAttention

初始化方法

    def __init__(self, embed_dim, num_heads, dropout=0.1):

初始化方法接受三个参数:

  • embed_dim:嵌入维度。
  • num_heads:多头注意力中头的数量。
  • dropout:dropout比率。

属性定义

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

定义四个线性层,分别用于查询(Q)、键(K)、值(V)的变换和最终输出的投影。同时定义一个dropout层。

前向传播方法

    def forward(self, query, key, value, mask=None):

前向传播方法接受查询、键、值和掩码作为输入。

线性变换

        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

对查询、键、值进行线性变换。

重塑形状

        Q = Q.view(batch_size, -1, self.num_heads, self.heads_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.heads_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.heads_dim).transpose(1, 2)

将变换后的张量重塑并交换维度,以适应多头注意力的计算。

计算注意力分数

        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.heads_dim ** 0.5)

计算查询和键的点积,并应用缩放因子。

应用掩码

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

如果提供了掩码,则将掩码中为0的位置的注意力分数设置为负无穷大。

应用Softmax和Dropout

        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

对注意力分数应用softmax函数进行归一化,并应用dropout。

计算加权值

        attn_output = torch.matmul(attn_probs, V)

使用归一化的注意力分数和值计算加权和。

重塑输出

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

将加权和的张量重新塑形。

应用输出线性层

        output = self.out_proj(attn_output)

对重塑后的张量应用输出线性层。

返回输出

        return output

返回多头注意力的输出结果。

示例使用

展示如何实例化和使用MultiHeadAttention模块。

embed_dim = 256
num_heads = 8
mha = MultiHeadAttention(embed_dim, num_heads)

dummy_query = torch.rand(10, 50, embed_dim)
dummy_key = torch.rand(10, 50, embed_dim)
dummy_value = torch.rand(10, 50, embed_dim)

output = mha(dummy_query, dummy_key, dummy_value)
print(output.shape)  # 应该输出形状 [batch_size, seq_length, embed_dim]

个人水平有限,有问题随时交流~

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值