- 看见网上有要手撕注意力的面经,自己开始写,以为很简单,实际上自己菜的要死,遂写本博客。
公式
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K t d ) V Attention(Q,K,V) = softmax(\frac{QK^t}{\sqrt d})V Attention(Q,K,V)=softmax(dQKt)V
代码
def myattention(q, k, v, d_k, mask=None, dropout=None):
# 首先计算注意力得分
score = torch.matmul(q,k.transpose(-2,-1)) / math.sqrt(d_k)
# 处理mask信息,如果被mask就变成一个很小的数
if mask is not None:
mask = mask.unsqueeze(1)
score = score.masked_fill(mask == 0,-1e9)
# 对注意力得分进行softmax
score = F.softmax(score, dim=-1) # 沿着最后一个维度计算 其他维度不变
# 处理dropout
if dropout is not None:
score = dropout(score)
# 点乘value矩阵
ans = torch.matmul(score,v)
return ans