torch.einsum 解析

请关注我的微信公众号,谢谢啦
请关注我的微信公众号,谢谢啦
最近在看VM-UNet[1]的代码实现,其中高频出现torch.einsum,并且与此前我见过的使用方式不同,因此记录一下用法。

torch.einsum是pytorch上的一个强大的函数,用于矩阵相关的计算,注意,这里没有限定为矩阵乘法。torch.einsum基于爱因斯坦求和约定执行张量操作,能够用简洁的表达式实现复杂的多维数组操作,从而避免繁琐的张量操作组合(如reshape、permute、bmm等),减少错误率。需要说明的是,尽管einsum函数内部进行了大量计算优化,但其主要优势在于表达式简洁,如果与单步reshape等pytorch实现的矩阵运算操作相比,其运算速度与内存占用不一定占优势。

下面介绍爱因斯坦求和约定。爱因斯坦求和约定是一种简洁的张量运算符号系统。其基本思想是通过省略某些重复出现的索引,从而自动隐式进行求和操作。下面通过例子说明具体表达方式。
  1. 矩阵乘法:‘ij,jk->ik’ 表示形状为(i,j)与形状为(j,k)的矩阵进行矩阵乘法,得到新矩阵形状为(i,k)。这也是torch.einsum最常规的用法。

  2. 维度调换:'ij->ji’表示形状为(i,j)的矩阵维度调换成为形状为(j,i)的矩阵。

    上述操作限定在二维矩阵,且能够仅通过pytorch的一个函数实现,如bmm、permute。但我们在深度学习模型实现中常常遇到高维数据,例如对二维图片进行分块等操作,此时,往往需要组合使用flatten、cat、stack、reshape、permute、bmm等操作来实现特定的矩阵运算,这不仅增加了代码实现的难度,也极大降低了代码开发效率与执行效率。这时候,einsum的优势就十分明显,往往可以通过一行代码解决战斗。例如VM-UNet中调用的SS2D函数中有一段代码如下:

x_dbl = torch.einsum(“b k d l, k c d -> b k c l”, xs.view(B, K, -1, L), self.x_proj_weight)
我们仅看 “b k d l, k c d -> b k c l” 部分。第一个矩阵形状为(b,k,d,l),第二个矩阵形状为(k,c,d),输出形状为(b,k,c,l)。根据上面两个例子,我们发现,完全看不懂这个怎么来的,下面我们补充两个限定条件。

einsum可以看做是针对张量中每个特定元素进行操作,因此不必过多考虑矩阵运算中的维度匹配问题。

运算表达式 “b k d l, k c d -> b k c l” 中,k 在输入表达式 b k d l 和 k c d 中都出现,且出现在输出表达式 b k c l 中,该维度仅进行点乘,无求和运算。 b 仅在第一个输入表达式出现,l 仅在第一个输入表达式出现,c 仅在第二个输入表达式出现,因此 blc 均会出现在最终表达式,因为没有向量求和消掉维度。而 d 维度出现在两个输入表达式中,且没有出现在输出表达式中,说明其通过两个张量的对应维度点乘求和消掉了。

下面我们思考一下,如果要对上述两个张量采用矩阵运算,应该包含哪些步骤。

我们将张量运算设为F,两个输入张量分别为A和B,输出张量为C。则F(A, B)=C。

首先将B形状扩充为(b,k,c,d),然后调换维度到(b,k,d,c),再将A(b,k,d,l)转换为A(bkd,l),B转化为B(bkd,c),然后进行外积,得到(bkd,l,c)。然后展开 (b, k, d, l, c)。将d维度sum求和去掉得到C(b,k,l,c),最后再调换维度得到C(b, k, c, l)。

可以看到,短短一行表达式,就包含了如此多的矩阵操作,可见einsum是我们进行复杂张量运算的必备神器,大家赶快学起来吧!

下一次将对市面上已有的典型Mamba分割模型进行分析,敬请期待。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

鹤城北斗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值