'''
Description: attention mechanism
Author: 365JHWZGo
Date: 2021-12-14 17:06:11
LastEditors: 365JHWZGo
LastEditTime: 2021-12-14 22:23:54
'''
The attention mechanism in three steps, with a step-by-step code walkthrough: (1) score the Query against the Key, (2) normalize the scores into weights with softmax, (3) take the weighted sum of the Values.
Import the libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
The Attn class
class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2):
        super(Attn, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1  # sequence length of V (number of value vectors)
        self.value_size2 = value_size2  # feature dimension of V (not used by any layer)
        # W_a: maps the concatenated [Q,K] vector to one score per value vector
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)

    def forward(self, q, k, v):
        # q[0], k[0] are (1, query_size) and (1, key_size); cat along dim=1 -> (1, 64)
        # attn_weights = (1, 32)
        attn_weights = F.softmax(self.attn(torch.cat((q[0], k[0]), dim=1)), dim=1)
        # attn_weights.unsqueeze(0) = (1, 1, 32)
        # v = (1, 32, 64)
        # output = (1, 1, 64)
        output = torch.bmm(attn_weights.unsqueeze(0), v)
        return output, attn_weights
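Note that value_size2 only records V's feature dimension; no layer consumes it. What does matter is that the linear layer's output size equals V's sequence length (value_size1 = 32), since each score weights one value vector.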
The self.attn layer acts on [Query|Key], the column-wise concatenation of Query and Key:
f(Q, K) = W_a[Q, K]
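To see this step in isolation, here is a minimal sketch with random tensors, using the same shapes as the demo below (w_a stands in for the class's self.attn layer):
import torch
import torch.nn as nn

q = torch.randn(1, 1, 32)
k = torch.randn(1, 1, 32)
# q[0] and k[0] are each (1, 32); concatenating along dim=1 merges the columns
qk = torch.cat((q[0], k[0]), dim=1)
print(qk.shape)       # torch.Size([1, 64])
# W_a maps the 64-dim [Q,K] vector to value_size1 = 32 scores
w_a = nn.Linear(32 + 32, 32)
print(w_a(qk).shape)  # torch.Size([1, 32])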
The entries of attn_weights correspond to the weights a1, a2, a3, …
output is the Attention Value: the bmm computes a1·value1 + a2·value2 + … as a single matrix multiplication.
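A quick sanity check of both claims, as a sketch with random tensors (here scores stands in for the pre-softmax output of self.attn):
import torch
import torch.nn.functional as F

scores = torch.randn(1, 32)
v = torch.randn(1, 32, 64)
attn_weights = F.softmax(scores, dim=1)
print(attn_weights.sum())                                   # ~1: a1 + a2 + ... = 1
out_bmm = torch.bmm(attn_weights.unsqueeze(0), v)           # (1, 1, 64)
out_sum = (attn_weights[0].unsqueeze(1) * v[0]).sum(dim=0)  # a1*value1 + a2*value2 + ...
print(torch.allclose(out_bmm[0, 0], out_sum))               # True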
if __name__ == "__main__":
    query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    attn = Attn(query_size, key_size, value_size1, value_size2)
    Q = torch.randn(1, 1, 32)
    K = torch.randn(1, 1, 32)
    V = torch.randn(1, 32, 64)
    out = attn(Q, K, V)
    print(out[0])  # attention output, shape (1, 1, 64)
    print(out[1])  # attention weights, shape (1, 32)
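Running the demo prints the attention output (shape (1, 1, 64)) and the attention weights (shape (1, 32)), matching the shape annotations in forward.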