import torch.nn as nn
import torch
import matplotlib.pyplot as plt
class Self_Attention(nn.Module):
def __init__(self, dim, dk, dv):
super(Self_Attention, self).__init__()
self.scale = dk ** -0.5
self.q = nn.Linear(dim, dk)
self.k = nn.Linear(dim, dk)
self.v = nn.Linear(dim, dv)
def forward(self, x):
q = self.q(x)
k = self.k(x)
v = self.v(x)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = attn @ v
return x
att = Self_Attention(dim=2, dk=2, dv=3)
x = torch.rand((1, 4, 2))
output = att(x)
代码中首先创建了一个self-attention的类,然后在随机出输入x,(1,4,2) 1是batch_size 4是4个token,2是每个token的长度
然后把x传入对象,在self-attention的初始化函数中定义了 1 d k \frac{1}{\sqrt{d_k}} dk1,并且从输入中提出q,k,v,其中q,k的维度是一定要保持一致的,
这里面 X的输入是(batchsize,num,dim_in)num是一维序列中token的个数,这里a1到a4就4个,dim_in是每个token的特征维数,这里每一个a都是1*2的向量,特征维度为2,dim_in就为2
对于Q、K、V的维度,W1 W2 W3分别是(dim_in,dq) (dim_in,dk) (dim_in,dv) 只不过dq肯定等于dk
这样X(batchsize,num,dim_in)才能分别与W1 、W2、W3相乘得到Q、K、V
Q(batchsize,num,dq) K(batchsize,num,dk) V(batchsize,num,dv)
这样
Q
K
T
QK^T
QKT (batchsize,num,dq)*(batchsize,dk,num) = (batchsize,num,num)
K T K^T KT是依靠k.transpose(-2,-1)实现的
self.scale = dk ** -0.5这一步是计算 1 d k \frac{1}{\sqrt{d_k}} dk1
self.q = nn.Linear(dim, dk)
self.k = nn.Linear(dim, dk)
self.v = nn.Linear(dim, dv)
torch.nn.Linear的用法可以参考官方的文档
这里的三个操作实际上就构建了W1,W2,W3矩阵,这三个矩阵分别与X相乘就得到了Q K V
def forward(self, x):
q = self.q(x)
k = self.k(x)
v = self.v(x)
这里就完成了Q K V的计算
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = attn @ v
这三步就分别完成了相似度分数的计算、相似度分数的归一化和最终计算