前言
我的导师说:“Attention Is All You Need”这篇论文写得很拉,要想弄懂Attention还得是这个博客,依据这个博客,我们可以手搓Self-Attention了!
步骤
1. 根据输入得到相应的Q、K、V矩阵
2. 一系列运算
2.1 得到Q和K的相关性分数α(Q和K做内积)
2.2 得到Q和K的相关性分数α'(先除以根号dk进行放缩,然后softmax归一化)
2.3 提取出V的信息(V与α'相乘)
实现
1. 代码
在我的实例中,数据维度是(B,C,L),B表示batch size,C表示channel,L表示序列长度。
d_m就是论文中的,d_k就是论文中的。
class Attention(nn.Module):
def __init__(self, d_m, d_k) -> None:
super(Attention, self).__init__()
self.d_m = d_m
self.d_k = d_k
self.wq = nn.Linear(in_features=self.d_m, out_features=self.d_k)
self.wk = nn.Linear(in_features=self.d_m, out_features=self.d_k)
self.wv = nn.Linear(in_features=self.d_m, out_features=self.d_k)
def forward(self, x):
q, k, v = self.wq(x), self.wk(x), self.wv(x)
score = pt.einsum('nci, ncj -> nc', q, k)
score /= pt.sqrt(pt.tensor(self.d_k))
score = nn.functional.softmax(score, -1)
# 因为做内积少了一维,所以需要增加第三个维度再做乘积
out = v * score[:,:,None]
return out
2. einsum
这里用到了爱因斯坦求和,这个函数很强大,可以进行各种矩阵运算。这里简单说明一下。
2.1 参数一:表达式
主旨就是通过矩阵的维度变化表示矩阵做的运算,字母是随意的,只要能代表相应维度即可。
在我的代码中,'nci, ncj -> nc'就表示对两个矩阵的第三个维度做内积,对应元素相乘再相加形成一个标量,所以减少了第三维,之后需要扩充出第三维。
2.2 参数二:参与运算的矩阵
在我的代码中,参与运算的矩阵就是q和k。
2.3 官网示例
''' 求矩阵的迹 '''
torch.einsum('ii', torch.randn(4, 4))
''' tensor(-1.2104) '''
''' 提取对角线元素 '''
torch.einsum('ii->i', torch.randn(4, 4))
''' tensor([-0.1034, 0.7952, -0.2433, 0.4545]) '''
''' 外积 '''
x = torch.randn(5)
y = torch.randn(4)
torch.einsum('i,j->ij', x, y)
'''
tensor([[ 0.1156, -0.2897, -0.3918, 0.4963],
[-0.3744, 0.9381, 1.2685, -1.6070],
[ 0.7208, -1.8058, -2.4419, 3.0936],
[ 0.1713, -0.4291, -0.5802, 0.7350],
[ 0.5704, -1.4290, -1.9323, 2.4480]])
'''
''' 批量矩阵相乘 '''
As = torch.randn(3, 2, 5)
Bs = torch.randn(3, 5, 4)
torch.einsum('bij,bjk->bik', As, Bs)
'''
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
'''
''' 具有子列表格式和省略号 '''
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
'''
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
'''
''' 批量维度变换 '''
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
''' torch.Size([2, 3, 5, 4]) '''
''' 与torch.nn.functional.bilinear等价 '''
A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
torch.einsum('bn,anm,bm->ba', l, A, r)
'''
tensor([[-0.3430, -5.2405, 0.4494],
[ 0.3311, 5.5201, -3.0356]])
'''
实验效果
所以我手搓的Self-Attention效果如何呢?
1. 准确率和运行速度:
左边是直接调用nn.MultiheadAttention,heads设为1,右边是我手搓的:
可以看到虽然在一开始手搓的准确率不如直接调用,但是在后来而这准确率相差不是很大;直接调用耗时3h+,手搓耗时2h20min,速度有优势。
2. 内存占用
在我的项目中,使用手搓的要比直接调用少占用6个G的显存!当然因项目而异了,不过也可以看出pytorch提供的接口里面参数量很大。
结尾
有问题欢迎在评论区讨论!