缩放点积注意力(Scaled Dot-Product Attention)是一种注意力机制,常用于Transformer模型中。
在Transformer模型中,Scaled Dot-Product Attention被用于实现Multi-Head Attention。具体来说,Multi-Head Attention将输入矩阵分别进行多个头的线性变换,然后对每个头的变换结果分别计算Scaled Dot-Product Attention,最后将每个头的Attention结果拼接在一起并通过一个线性变换输出。
Scaled Dot-Product Attention的计算方式如下:
- 计算Query矩阵Q、Key矩阵K的乘积,得到得分矩阵scores。
- 对得分矩阵scores进行缩放,即将其除以向量维度的平方根(np.sqrt(d_k))。
- 若存在Attention Mask,则将Attention Mask的值为True的位置对应的得分矩阵元素置为负无穷(-inf)。
- 对得分矩阵scores进行softmax计算,得到Attention权重矩阵attn。
- 计算Value矩阵V和Attention权重矩阵attn的乘积,得到加权后的Context矩阵。
这是一段缩放点积注意力的代码:
class ScaledDotProductAttention(nn.Module):
def __init__(self, mask_value=-1e9):
super(ScaledDotProductAttention, self).__init__()
self.mask_value = mask_value
def forward(self, Q, K, V, attn_mask=None):
'''
Q: [batch_size, n_heads, len_q, d_k]
K: [batch_size, n_heads, len_k, d_k]
V: [batch_size, n_heads, len_v(=len_k), d_v]
attn_mask: [batch_size, n_heads, seq_len, seq_len]
'''
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]
if attn_mask is not None:
scores.masked_fill_(attn_mask, self.mask_value)
attn = nn.Softmax(dim=-1)(scores)
context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]
return context, attn
核心代码是这一句:
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
将scores除以d_k的平方根(np.sqrt(d_k)),这就是所谓的缩放,可以避免得分过大或过小。
通过这种方式,Scaled Dot-Product Attention可以计算出Query和Key之间的相似度,同时考虑了Value矩阵对最终结果的影响,进而实现了注意力机制的作用。