torch.nn.MultiheadAttention
是 PyTorch 中实现多头注意力机制(Multi-head Attention)的模块,通常用于神经网络中的注意力机制,如 Transformer 模型。
模块的功能
MultiheadAttention
实现了多头注意力机制,其中每个注意力头独立学习不同的表示,然后将其组合。- 它接收查询(query)、键(key)和值(value)作为输入,并返回注意力输出。
- 支持掩码(mask),可用于处理序列中的填充(padding)或添加注意力约束。
类定义
torch.nn.MultiheadAttention(
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
add_bias_kv: bool = False,
add_zero_attn: bool = False,
kdim: Optional[int] = None,
vdim: Optional[int] = None,
batch_first: bool = False,
)
参数说明
embed_dim
:嵌入的维度(必须是num_heads
的整数倍)。num_heads
:注意力头的数量。dropout
:注意力分数中的 dropout 概率(默认0.0
)。bias
参数控制的是 Q、K、V 和输出投影线性变换中的偏置项。默认情况下(bias=True
),这些线性变换包含偏置。add_bias_kv
:是否为键和值添加可学习的偏置(默认False
)。add_zero_attn
:是否在注意力权重计算前添加全零向量(默认False
)。kdim
和vdim
:键和值的特征维度(如果未指定,则与embed_dim
相同)。batch_first
:是否使用[batch, seq, embed_dim]
格式的输入(默认False
)。
方法
forward
MultiheadAttention.forward(
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]
query
、key
、value
:输入张量,形状为[seq_len, batch_size, embed_dim]
(如果batch_first=True
,则[batch_size, seq_len, embed_dim]
)。