PyTorch implementation of self-attention

import math
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, n, d):
        super().__init__()
        # Project the n-dimensional word embeddings into
        # d-dimensional query, key, and value spaces.
        self.WQ = nn.Linear(n, d)
        self.WK = nn.Linear(n, d)
        self.WV = nn.Linear(n, d)
        self.d = d

    def forward(self, X):
        # X: (seq_len, n) matrix of word embeddings
        Q = self.WQ(X)
        K = self.WK(X)
        V = self.WV(X)
        # Scaled dot-product attention: scale by sqrt(d), not d,
        # and take the softmax over the key dimension.
        scores = torch.matmul(Q, K.T) / math.sqrt(self.d)
        att = torch.matmul(torch.softmax(scores, dim=-1), V)
        return att

if __name__ == '__main__':      # note: '__main__', not 'main'
    X = torch.rand(10, 500)     # 10 tokens, embedding size 500
    print(SelfAttention(500, 10)(X).shape)  # torch.Size([10, 10])
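The module above computes standard scaled dot-product attention, softmax(Q K^T / sqrt(d)) V, where each row of the softmax output holds one token's attention weights over all tokens. It only handles a single (seq_len, n) matrix; as a minimal sketch of the same computation for batched input (the function name and batched signature are assumptions, not from the original post), only the transpose changes:

import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # Hypothetical helper; Q, K, V: (batch, seq_len, d)
    d = Q.size(-1)
    # Transpose the last two dims so scores is (batch, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d)
    return torch.matmul(torch.softmax(scores, dim=-1), V)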