代码实现和注释如下
#导入所需要的软件包
import torch
import torch.nn as nn
import math
#定义多头注意力的类,并继承nn.Module基类
class MultiHead(nn.Module):
#初始化函数,这里初始化时引入n_model:Embedding之后的向量维度,n_head:有几个头
def __init__(self, n_model, n_head):
super(MultiHead, self).__init__() #继承父类的初始化函数
self.n_model = n_model
self.n_head = n_head
self.w_q = nn.Linear(n_model,n_model) #定义Query(q)向量
self.w_k = nn.Linear(n_model,n_model) #定义Key(k)向量
self.w_v = nn.Linear(n_model,n_model) #定义Value(v)向量
self.combine = nn.Linear(n_model,n_model) #定义一个combine向量便于后面的输出处理
#定义softmax函数,dim=-1意指在最后一个维度上进行softmax
self.softmax = nn.Softmax(dim=-1)
def forward(self, q, k, v): #定义前向函数,参数玮q,k,v向量
batch, time, dimension = q.shape() #这里q,k,v的三维分别是batch, time, dimension
n_d = self.n_model//self.n_head #计算多头注意力中,分头之后各头可以分到的维度大小
q = self.w_q(q) #赋值q, k, v
k = self.w_k(k)
v = self.w_v(v)
#.view(batch, time, n_head, n_d)将q,k,v拆开为四个维度
#将原来的dimension维拆成分别为n_head 和 n_d大小的两个维度
#.permute(0,2,1,3)将原来的第一维和第二维调换,即原来是(0,1,2,3)调为(0,2,1,3)
q = q.view(batch, time, n_head, n_d).permute(0,2,1,3)
k = k.view(batch, time, n_head, n_d).permute(0,2,1,3)
v = v.view(batch, time, n_head, n_d).permute(0,2,1,3)
#计算注意力得分,按照公式,即Query和Key转置的点积再除以根号下n_d,
#经过激活函数之后再点乘Value
score = q @ k.transpose(2,3)/math.sqrt(n_d)
#定义mask掩码
#torch.ones(time, time, dtype = bool)会生成一个大小为time*time的全1矩阵
#torch.tril接受这个全1矩阵,并生成一个下三角的矩阵
#即对角线以下为1(TRUE),对角线以上为0(FALSE)
mask = torch.tril(torch.ones(time, time, dtype = bool))
#加上掩码,即对应掩码为0处的score的值置为负无穷,这样经过激活函数时便会变为0
score = score.masked_fill(mask==0, float("-inf"))
#经过softmax,并点成v得到注意力得分
score = self.softmax(score) @ v
#将注意力得分变换为原先输入矩阵大小
#.permute(0,2,1,3)交换第一维和第二维
#contiguous()确保矩阵内的数值连续
#.view(batch, time, dimension) 将矩阵变换为原输入大小
#本例中此时score大小为(256,64,128)
score = score.permute(0,2,1,3).contiguous().view(batch, time, dimension)
#输出矩阵,此时将d_model维度映射到d_model维度,这里重新映射的操作比较简单
#一般来说可以通过combine这个线性层维护这之前经过了激活函数的非线性
#不然整个注意力机制就是一个线性变换,模型学习不到特征
#同时,多引入了combine这层的参数,增加了参数多样性,
#提升了训练的稳定度和特征表示的学习能力
#重新映射也可以帮助调整输出权重,并且也便于之后添加归一化,dropout等操作,增加了灵活性
output = self.combine(score)
return score
if __name__ == "__main__":
x_test = torch.rand(256, 64, 128)
n_model = 128
n_head = 8
attention = MultiHead(n_model, n_head)
output = attention(x_test, x_test, x_test)
print(output, output.shape())
测试结果
在解释器中运行后,即可得到如下结果,输出的attention和输入q,k,v向量维度一致