深入解析PyTorch中MultiheadAttention的参数key_padding_mask与attn_mask

1. 基本背景

在multiheadattention中存在两个mask,一个参数是key_padding_mask,另外一个是attn_mask,尽管这两个参数是被人们所熟知的填充掩码和注意力掩码,但是深度理解以便清晰区分对于深刻理解该架构非常重要。

2. 参数Key_padding_mask(关键填充掩码)

  • 用途:防止模型关注到输入序列中用 <pad> 填充的位置。
  • 场景:对变长输入进行 padding 后,避免注意力将注意力权重分配到 padding token 上。
  • 应用位置:在计算注意力时,对 所有 query 的 key 位置 进行屏蔽。

✅维度

# key_padding_mask shape: (batch_size, seq_len)

✅ 示例

key_padding_mask = torch.tensor([[False, False, True], [False, True, True]])
# 表示第一个样本第3个位置是pad,第二个样本第2,3个位置是pad

3. 参数Attn_mask(注意力掩码)

  • 用途:对注意力矩阵中任意 query-key 对的连接进行屏蔽,更灵活。
  • 场景:
    • Transformer 解码器中的 自回归遮蔽(causal mask)
    • 限定注意力只能在局部范围内滑动(局部注意力)
    • 自定义 mask,如节省计算或实验结构

✅ 维度

# [tgt_len, src_len](用于所有 batch 和 head)
# 或 [batch_size * num_heads, tgt_len, src_len](用于每个 head 的个性化 mask)

✅ 示例:causal mask

# 上三角为 True,代表“未来的信息被屏蔽”,用于解码器自回归。
tgt_len = 5
attn_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()

4. 工作流程中的区别⚠️⚠️⚠️

在计算 Q ∗ K T Q*K^T QKT之后:

  1. 先应用attn_mask(对齐注意力矩阵维度,屏蔽某些query-key配对);
  2. 再应用key_padding_mask(对每个样本的padding key屏蔽);
  3. 最后经过softmax处理

5. 类比理解

  • key_padding_mask 像是说:“这些 token 是 padding,不用关注它们。”
  • attn_mask 像是说:“这些 query-key 配对不允许有连接(比如未来的信息)。”
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值