torch.nn.MultiheadAttention模块介绍

torch.nn.MultiheadAttention 是 PyTorch 中实现多头注意力机制(Multi-head Attention)的模块,通常用于神经网络中的注意力机制,如 Transformer 模型。


模块的功能

  • MultiheadAttention 实现了多头注意力机制,其中每个注意力头独立学习不同的表示,然后将其组合。
  • 它接收查询(query)、键(key)和值(value)作为输入,并返回注意力输出。
  • 支持掩码(mask),可用于处理序列中的填充(padding)或添加注意力约束。

类定义

torch.nn.MultiheadAttention(
    embed_dim: int,
    num_heads: int,
    dropout: float = 0.0,
    bias: bool = True,
    add_bias_kv: bool = False,
    add_zero_attn: bool = False,
    kdim: Optional[int] = None,
    vdim: Optional[int] = None,
    batch_first: bool = False,
)
参数说明
  • embed_dim:嵌入的维度(必须是 num_heads 的整数倍)。
  • num_heads:注意力头的数量。
  • dropout:注意力分数中的 dropout 概率(默认 0.0)。
  • bias 参数控制的是 Q、K、V 和输出投影线性变换中的偏置项默认情况下(bias=True),这些线性变换包含偏置
  • add_bias_kv:是否为键和值添加可学习的偏置(默认 False)。
  • add_zero_attn:是否在注意力权重计算前添加全零向量(默认 False)。
  • kdim 和 vdim:键和值的特征维度(如果未指定,则与 embed_dim 相同)。
  • batch_first:是否使用 [batch, seq, embed_dim] 格式的输入(默认 False)。

方法

forward
MultiheadAttention.forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]
  • querykeyvalue:输入张量,形状为 [seq_len, batch_size, embed_dim](如果 batch_first=True,则 [batch_size, seq_len, embed_dim])。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值