import math
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    # shapes: query = key = value -> [batch_size, 8, max_length, 64]
    d_k = query.size(-1)
    # after the transpose, key has shape [batch_size, 8, 64, max_length],
    # so scores has shape [batch_size, 8, max_length, max_length]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # padding mask
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # analysis point 1
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
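A minimal smoke test of the function above. The concrete sizes (batch of 2, 8 heads, max_length of 10, head dimension 64) are illustrative assumptions taken from the shape comments, not fixed by the source:

# Hypothetical shapes: batch_size=2, 8 heads, max_length=10, d_k=64.
q = torch.randn(2, 8, 10, 64)
k = torch.randn(2, 8, 10, 64)
v = torch.randn(2, 8, 10, 64)
# Padding mask: 1 = real token, 0 = padding; broadcast over heads and queries.
mask = torch.ones(2, 1, 1, 10, dtype=torch.long)
mask[0, ..., 7:] = 0  # pretend sequence 0 has only 7 real tokens
out, p_attn = attention(q, k, v, mask=mask)
print(out.shape)     # torch.Size([2, 8, 10, 64])
print(p_attn.shape)  # torch.Size([2, 8, 10, 10])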
Source code analysis
Analysis point 1: scores.masked_fill(mask == 0, -1e9)
In the scaled dot-product part of the Transformer, scores.masked_fill(mask == 0, -1e9) fills every position where the mask matrix is 0 with -1e9, a very large negative number. After softmax, the probability assigned to those positions is effectively 0, so padded positions receive no attention weight.
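A quick standalone check (not from the source) of why -1e9 behaves like negative infinity under softmax:

scores = torch.tensor([2.0, 1.0, -1e9])
print(F.softmax(scores, dim=-1))
# tensor([0.7311, 0.2689, 0.0000]): the masked position gets zero weight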
For details on the masked_fill function, see: https://blog.csdn.net/qq_41568188/article/details/107281395
The masking operation fills the elements of the tensor that correspond to True entries of the boolean mask (here, the positions where mask == 0) with value.
The mask does not need the exact same shape as the tensor being filled; it only needs to be broadcastable to that shape.
import torch
import torch.nn.functional as F

a = torch.randn(5, 6)
print(a)
x = [5, 4, 3, 2, 1]  # one length per row
mask = torch.zeros(5, 6, dtype=torch.int)
for e_id, src_len in enumerate(x):
    mask[e_id, src_len:] = 1  # mark positions from src_len onward with 1
print(mask)
# Fill the positions where mask == 0 with -inf.
# Note: in this toy example the kept positions are those with mask == 1,
# the opposite convention to the attention() function above.
a.masked_fill_(mask == 0, float('-inf'))
print(a)
print(F.softmax(a, dim=-1))
# output
tensor([[-1.8453, -0.7031, 0.0066, -1.0771, -0.5282, -0.1669],
[ 1.0285, -0.1086, -0.9871, -1.2061, -0.7845, 1.5072],
[ 1.8313, 2.4513, -0.1615, -1.2768, -0.5887, -1.2990],
[ 0.4653, 0.7976, 0.2020, -0.0886, -0.9101, -2.9927],
[-0.5556, -0.5319, -0.1768, -0.4238, 1.2213, -1.9120]])
tensor([[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1]], dtype=torch.int32)
tensor([[ -inf, -inf, -inf, -inf, -inf, -0.1669],
[ -inf, -inf, -inf, -inf, -0.7845, 1.5072],
[ -inf, -inf, -inf, -1.2768, -0.5887, -1.2990],
[ -inf, -inf, 0.2020, -0.0886, -0.9101, -2.9927],
[ -inf, -0.5319, -0.1768, -0.4238, 1.2213, -1.9120]])
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0918, 0.9082],
[0.0000, 0.0000, 0.0000, 0.2520, 0.5015, 0.2465],
[0.0000, 0.0000, 0.4722, 0.3531, 0.1553, 0.0194],
[0.0000, 0.1045, 0.1491, 0.1165, 0.6035, 0.0263]])
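In the real attention call the padding mask is usually built once per batch element and broadcast across heads and query positions, relying on the broadcastability noted above. A sketch of that form, with assumed sizes (batch of 2, 8 heads, sequence length 6):

# A [batch, 1, 1, seq_len] mask broadcasts against scores of shape
# [batch, heads, seq_len, seq_len] inside masked_fill.
scores = torch.randn(2, 8, 6, 6)
lengths = torch.tensor([6, 4])            # hypothetical valid lengths
positions = torch.arange(6).unsqueeze(0)  # [1, 6]
mask = (positions < lengths.unsqueeze(1)).view(2, 1, 1, 6)  # 1 = real token
masked = scores.masked_fill(mask == 0, -1e9)
print(F.softmax(masked, dim=-1)[1, 0, 0])  # last two weights are ~0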