超平实版Pytorch Self-Attention: 参数详解(尤其是mask)(使用nn.MultiheadAttention)

本文介绍了PyTorch中实现Self-Attention的细节,特别是nn.MultiheadAttention模块的使用。通过query、key、value进行前向传播,输出包括attention输出和权重。讨论了mask的两种类型——key_padding_mask和attn_mask,解释了它们在处理padding和避免未来信息泄漏中的作用,并阐述了如何通过-inf使得注意力权重接近于0。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Self-Attention的结构图

本文侧重于Pytorch中对self-attention的具体实践,具体原理不作大量说明,self-attention的具体结构请参照下图。
在这里插入图片描述
(图中为输出第二项attention output的情况,k与q为key、query的缩写)

本文中将使用Pytorch的torch.nn.MultiheadAttention来实现self-attention.

forward输入中的query、key、value

首先,前三个输入是最重要的部分query、key、value。由图1可知,我们self-attention的这三样东西其实是一样的,它们的形状都是:(L,N,E) 1

L:输入sequence的长度(例如一个句子的长度)
N:批大小(例如一个批的句子个数)
E:词向量长度

forward的输出

输出的内容很少只有两项:

  1. attn_output
    即通过self-attention之后,从每一个词语位置输出来的attention。其形状为(L,N,E),是和输入的query它们形状一样的。因为毕竟只是给value乘了一个weight。

  2. attn_output_weights
    即attention weights,形状是(N,L,L),因为每一个单词和任意另一个单词之间都会产生一个weight,所以每一句句子的weight数量是L*L

实例化一个nn.MultiheadAttention

这里对MultiheadAttention进行一个实例化并传入一些参数,实例化之后我们得到的东西我们就可以向它传入input了。

实例化时的代码:

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

其中,embed_dim是每一个单词本来的词向量长度;num_heads是我们MultiheadAttention的head的数量。

pytorch的MultiheadAttention应该使用的是Narrow self-attention机制,即,把embedding分割成num_heads份,每一份分别拿来做一下attention。

也就是说:单词1的第一份、单词2的第一份、单词3的第一份…会当成一个sequence,做一次我们图1所示的self-attention。
然后,单词1的第二份、单词2的第二份、单词3的第二份…也会做一次
直到单词1的第num_heads份、单词2的第num_heads份、单词3的第num_heads份…也做完self-attention

从每一份我们都会得到一个(L,N,E)形状的输出,我们把这些全部concat在一起,会得到一个(L,N,E*num_heads)的张量。

这时候,我们拿一个矩阵,把这个张量的维度变回(L,N,E)即可输出。

进行forward操作

我们把我们刚才实例化好的multihead_attn拿来进行forward操作(即输入input得到output):

attn_output, attn_output_weights = multihead_attn(query, key, value)

关于mask

mask可以理解成遮罩、面具,作用是帮助我们“遮挡”掉我们不需要的东西,即让被遮挡的东西不影响我们的attention过程。

在forward的时候,有两个mask参数可以设置:

  1. key_padding_mask
    每一个batch的每一个句子的长度一般是不可能完全相同的,所以我们会使用padding把一些空缺补上。而这里的这个key_padding_mask是用来“遮挡”这些padding的。
    这个mask是二元(binary)的,也就是说,它是一个矩阵和我们key的大小是一样的,里面的值是1或0,我们先取得key中有padding的位置,然后把mask里相应位置的数字设置为1,这样attention就会把key相应的部分变为"-inf". (为什么变为-inf我们稍后再说)

  2. attn_mask
    这个mask经常是用来遮挡“正确答案”的:
    假如你想要用这个模型每次预测下一个单词,我们每一个位置的attention输出是怎么得来的?是不是要看一遍整个序列,然后每一个单词都计算一个attention weight?那也就是说,你在预测第5个词的时候,你其实会看到整个序列,这样的话你在预测之前不就已经知道第5个单词是什么了,这就是作弊了。
    我们不想让模型作弊,因为在真实使用这个模型去预测的时候,我们是没有整个序列的信息的。那么怎么办?那就让第5个单词的attention weight=0吧,即声明:我不想看这个单词,我的注意力一点也别分给它。

如何让这个weight=0:
我们先想象一下,我们目前拥有的attention scores是什么样的?(注:attention_score是attention_weight的初始样子,经过softmax之后会变成attention_weight.
attention_score和weight的形状是一样的,毕竟只有一个softmax的差别)

我们之前提到,attention weights的形状是L*L,因为每个单词两两之间都有一个weight。

如下图所示,我用蓝笔圈出的部分,就是“我想要预测x2”时,整个sequence的attention score情况。我用叉划掉的地方,是我们希望=0的位置,因为我们想让x2、x3、x4的权值为0,即:预测x2的时候,我们的注意力只能放在x1上。
在这里插入图片描述
对于其他行,你可以以此类推,发现我们需要一个三角形区域的attention weight=0, 这时候我们的attn_mask这时候就出场了,把这个mask做成三角形即可。

关于mask的题外话:有朋友好奇为什么有的地方看到的图mask了对角线有的没有,应该是因为sequence不同或者训练任务/方式不同,但本质上mask的原理是一样的。我再找一张图帮助大家理解,比如如果加上s(start)和e(end)的话就是类似这样:(图源2, 白色为mask掉的部分)
请添加图片描述

现在我们来说mask的值。和key_padding_mask不同,我们的attn_mask不是binary的,而是一个“additive mask”。

什么是additive mask呢?就是我们mask上设置的值,会被加到我们原本的attention score上。我们要让三角形区域的weight=0,我们这个三角mask设置什么值好呢?答案是-inf,(这个-inf在key_padding_mask的讲解中也出现了,这里就来说说为什么要用-inf)。我们上面提到了,attention score要经过一个softmax才变成attention_weights.
我们都知道softmax的式子可以表示为3
σ ( z ) j = e z j ∑ k = 1 K e z k \sigma(z)_j = \frac{e^{z_j}}{\sum_{k=1}^{K}e^{z_k}} σ(z)j=k=1Kezkezj
(for j = 1,…,K)

当我们attention score的值设置为-inf (可以看作这里式子里的 z j = − inf ⁡ z_j=-\inf zj=inf),于是通过softmax之后我们的attention weight就会趋近于0了,这就是为什么我们这里的两个mask都要用到-inf。

Reference


  1. Pytorch官方Documentation ↩︎

  2. https://flashgene.com/archives/63944.html ↩︎

  3. Wikipedia: Softmax函数 ↩︎

### torch.nn.MultiheadAttention 中的注意力得分计算 `torch.nn.MultiheadAttention` 是 PyTorch 提供的一个用于实现多头注意力机制的模块。其核心功能之一是通过 `scaled_dot_product_attention` 函数来计算注意力得分,这一过程基于 Scaled Dot-Product Attention 的定义。 #### 注意力得分公式的理论基础 在多头注意力机制中,注意力得分的计算依赖于 Query (Q) 和 Key (K) 向量之间的相似性评估。具体来说,注意力得分为: \[ \text{Attention Score} = \frac{\text{Query} \cdot \text{Key}^\top}{\sqrt{d_k}} \] 其中 \( d_k \) 表示 Key 向量的维度大小[^1]。这个公式的作用是对输入序列的不同位置进行加权评分,权重反映了不同位置的相关程度。 #### 实现细节分析 为了更清楚地理解 `torch.nn.MultiheadAttention` 如何计算注意力得分,可以从以下几个方面展开讨论: 1. **线性变换** 输入数据会被分别映射到 Query、Key 和 Value 空间。假设输入张量为 X,则有: \[ Q = W_Q \cdot X, \quad K = W_K \cdot X, \quad V = W_V \cdot X \] 这里的 \( W_Q, W_K, W_V \) 都是由模型学习得到的可训练参数矩阵[^2]。 2. **缩放点积注意力** 使用上述公式计算注意力得分之后,会对其进行 softmax 归一化处理,最终形成概率分布形式的注意力权重。完整的 scaled dot-product attention 可表示为: \[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \] 此外,在实际应用中可能会加入掩码(mask),以便屏蔽掉某些不需要关注的位置[^3]。 3. **并行加速与 GPU 支持** 自注意力机制虽然具有 O() 时间复杂度,但由于每一对查询键值对之间相互独立,因此非常适合利用现代硬件设备(如 GPU 或 TPU)来进行批量化的并行运算。这种特性显著提升了整体效率[^4]。 ```python import torch from torch import nn # 初始化 Multi-head Attention 层 multihead_attn = nn.MultiheadAttention(embed_dim=128, num_heads=8) # 构造随机输入张量作为例子 query = key = value = torch.randn(30, 10, 128) # shape: [seq_len, batch_size, embed_dim] # 执行前向传播获取输出以及注意力建议图谱 output, attn_weights = multihead_attn(query=query, key=key, value=value) print(output.shape) # 输出形状应等于 query 的形状 print(attn_weights.shape) # 返回的是注意力权重矩阵 ``` 以上代码片段展示了如何实例化一个多头注意力层,并传入相应的 Query、Key 和 Value 数据完成一次正向传递操作。
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值