一、目录
- 定义
- 代码实现
二、实现
- 定义
- 代码实现
#手撕 self attention
import torch
import torch.nn as nn
import numpy as np
class SelfAttention(nn.Module):
def __init__(self,hidden_dim,dim_q,dim_v):
super(SelfAttention,self).__init__()
self.hidden_dim=hidden_dim
self.dim_q=dim_q
self.dim_k=dim_q
self.dim_v=dim_v
self.linear_q=nn.Linear(self.hidden_dim,self.dim_q)
self.linear_k=nn.Linear(self.hidden_dim,self.dim_k)
self.linear_v=nn.Linear(self.hidden_dim,self.dim_v)
self.norm_fact=1/np.sqrt(self.dim_k) #保持均值、方差不变,使得训练过程中梯度值保持稳定
def forward(self,x):
q=self.linear_q(x) #为了提升模型的拟合能力,矩阵W都是可以训练的,起到一个缓冲的效果。
k=self.linear_k(x)
v=self.linear_v(x)
acore=torch.matmul(q,k.transpose(1,2))*self.norm_fact
#内积:以行向量的角度理解,里面保存了每个向量与自己和其他向量进行内积运算的结果,代表词的相关性
a=torch.softmax(acore,dim=-1)
att=torch.matmul(a,v)
return att
if __name__ == '__main__':
batch=2
seq_len=5
hidden_dim=4
x=torch.randn(batch,seq_len,hidden_dim)
attention=SelfAttention(hidden_dim,10,hidden_dim)
print(attention(x).shape)