对于论文中Pytorch代码爱因斯坦求和einsum的理解

本文详细介绍了PyTorch中的torch.einsum函数,它用于高效地进行张量运算,特别是对相同下标的求和。通过一个具体的例子,展示了如何使用einsum进行四维和三维张量的运算,得到四维结果。同时,通过for循环的方式手动实现了相同的运算过程,验证了einsum的正确性。这对于理解复杂的张量运算和代码实现非常有帮助。
摘要由CSDN通过智能技术生成

论文代码中经常会用到torch.einsum(),其本质是对相同下标求和
举例

torch.einsum('nctv, vtq -> ncqv', (x,y))

’ -> '左边是两个或多个元素原本的维度下标,右边是要得到的维度下标
理解这个公式最好的办法就是将其展开
这个例子是论文中的真实开源代码,求和的两边一个有四个维度,一个三个维度,最终得到四个维度的结果。

  • 使用zeros先创建结果维度的矩阵
  • 按照结果的维度按照外层循环展开,如最终结果有4个维度就创建4层循环
  • 内层循环即维度中消失的下标,在例子中消失的是t,即对t进行循环,在这之前创建temp准备记录求和结果
  • 内层循环求和,记录结果,将值赋给最终结果
import torch
x = torch.randn(32,3,10,22) # shape: n,c,t,v
y = torch.randn(22,10,10) # shape: v,t,q
res_1 = torch.einsum('nctv, vtq -> ncqv', (x,y))
print('einsum求和结果:')
print(res_1[1,1,1]) #维度太高写不下,就取最后一个维度看一看
print('-'*40)

res_2 = torch.zeros(32,3,10,22) #先将res_2维度设定为n,c,q,v
for n in range(32): #按照n,c,q,v展开外层循环
    for c in range(3):
        for q in range(10):
            for v in range(22):
                temp = 0 #准备记录求和结果
                for t in range(10): #内层循环,这个是消失的下标
                    temp += x[n,c,t,v] * y[v,t,q]
                res_2[n,c,q,v] = temp
print('for循环求和结果:')
print(res_2[1,1,1])

最后结果:
在这里插入图片描述
完全一致
所以当你看不懂论文中的爱因斯坦求和式子时,不妨用for循环将其展开来理解试试

参考:
https://blog.csdn.net/ashome123/article/details/117110042

  • 8
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

锌a

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值