Self-Attention(einsum)

前言

我的导师说:“Attention Is All You Need”这篇论文写得很拉,要想弄懂Attention还得是这个博客,依据这个博客,我们可以手搓Self-Attention了!

步骤

1. 根据输入得到相应的Q、K、V矩阵

2. 一系列运算

2.1 得到Q和K的相关性分数α(Q和K做内积)

2.2 得到Q和K的相关性分数α'(先除以根号dk进行放缩,然后softmax归一化)

2.3 提取出V的信息(V与α'相乘)

实现

1. 代码

在我的实例中,数据维度是(B,C,L),B表示batch size,C表示channel,L表示序列长度。

d_m就是论文中的d_{model},d_k就是论文中的d_k

class Attention(nn.Module):
    def __init__(self, d_m, d_k) -> None:
        super(Attention, self).__init__()
        self.d_m = d_m
        self.d_k = d_k
        self.wq = nn.Linear(in_features=self.d_m, out_features=self.d_k)
        self.wk = nn.Linear(in_features=self.d_m, out_features=self.d_k)
        self.wv = nn.Linear(in_features=self.d_m, out_features=self.d_k)

    def forward(self, x):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        score = pt.einsum('nci, ncj -> nc', q, k)
        score /= pt.sqrt(pt.tensor(self.d_k))
        score = nn.functional.softmax(score, -1)
        # 因为做内积少了一维,所以需要增加第三个维度再做乘积
        out = v * score[:,:,None]
        return out

2. einsum

这里用到了爱因斯坦求和,这个函数很强大,可以进行各种矩阵运算。这里简单说明一下。

2.1 参数一:表达式

主旨就是通过矩阵的维度变化表示矩阵做的运算,字母是随意的,只要能代表相应维度即可。

在我的代码中,'nci, ncj -> nc'就表示对两个矩阵的第三个维度做内积,对应元素相乘再相加形成一个标量,所以减少了第三维,之后需要扩充出第三维。

2.2 参数二:参与运算的矩阵

在我的代码中,参与运算的矩阵就是q和k。

2.3 官网示例

''' 求矩阵的迹 '''
torch.einsum('ii', torch.randn(4, 4))
''' tensor(-1.2104) '''

''' 提取对角线元素 '''
torch.einsum('ii->i', torch.randn(4, 4))
''' tensor([-0.1034,  0.7952, -0.2433,  0.4545]) '''

''' 外积 '''
x = torch.randn(5)
y = torch.randn(4)
torch.einsum('i,j->ij', x, y)
'''
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])
'''

''' 批量矩阵相乘 '''
As = torch.randn(3, 2, 5)
Bs = torch.randn(3, 5, 4)
torch.einsum('bij,bjk->bik', As, Bs)
''' 
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])
'''

''' 具有子列表格式和省略号 '''
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
'''
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])
'''

''' 批量维度变换 '''
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
''' torch.Size([2, 3, 5, 4]) '''

''' 与torch.nn.functional.bilinear等价 '''
A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
torch.einsum('bn,anm,bm->ba', l, A, r)
'''
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])
'''

实验效果

所以我手搓的Self-Attention效果如何呢?

1. 准确率和运行速度:

左边是直接调用nn.MultiheadAttention,heads设为1,右边是我手搓的:

 可以看到虽然在一开始手搓的准确率不如直接调用,但是在后来而这准确率相差不是很大;直接调用耗时3h+,手搓耗时2h20min,速度有优势。

2. 内存占用

在我的项目中,使用手搓的要比直接调用少占用6个G的显存!当然因项目而异了,不过也可以看出pytorch提供的接口里面参数量很大。

结尾

有问题欢迎在评论区讨论!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Burger~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值