pytorch中einsum详解

引言

在线性代数里,经常会遇到各种计算操作符号,比如矩阵的点积,外积, H H H积, K K K积,转置等。爱因斯坦求和约定提供了一套简单优雅的规则可以实现以上操作,目的是省略掉求和公式中的求和号“+”。

定义(爱因斯坦求和约定):如果两个相同的指标出现在指标符号公式的同一项中,则表示对该指标遍历整个取值范围求和。

爱因斯坦求和约定具体的规则可以归结为如下几条

  • 在同一项中,如果同一指标成对出现,就表示遍历其取值范围求和
  • 公式中成对出现的指标叫做哑指标亦或哑标,表示哑标的小写字母可以用另一个小写字母替换,并且其取值范围不变
  • 相同的指标求和指标,而其余的指标称为自由指标
  • 当两个求和式相乘时, 两个求和式的哑标不能使用相同的字母

当前爱因斯坦求和约定已经在numpykerastensorflowpytorch等库中进行了实现。本文主要介绍爱因斯坦求和约定在pytorch中的实现。

torch.einsum用法介绍

pytorch中实现爱因斯坦求和约定的函数是torch.einsum,该函数的功能和用法细节如下所示

torch.einsum(equation, *operands) ⟶ \longrightarrow Tensor

  • equation (string):表示爱因斯坦求和下标公式
  • operands (List[Tensor]):计算爱因斯坦求和的张量列表

注意事项:

  • torch.einsum函数的第一个参数表示输入和输出张量的维度,其中equation中的箭头左边表示输入张量,逗号分割开每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是 26 26 26个大小写英文字母’a’—‘z’’A’—‘Z’
  • torch.einsum函数的第二个参数表示实际输入的张量列表,其数量要与第一个参数equation的输入数量对应,并且其字符数量要与张量的真实维度对应,例如 i j , j k → i k \mathrm{ij,jk}\rightarrow \mathrm{ik} ij,jkik表示输入和输出张量都是 2 2 2维的
  • equation可以不写包括箭头在内的右边部分,则此时输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,例如的矩阵乘法 i j , j k → i k \mathrm{ij,jk}\rightarrow \mathrm{ik} ij,jkik也可以简化为 i j , j k \mathrm{ij,jk} ij,jk,根据默认规则,输出就是 i k \mathrm{ik} ik
  • equation中支持“ . . . ... ...”省略号,用于表示不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写成“ . . . i j → . . . j i \mathrm{...ij \rightarrow...ji} ...ij...ji

代码实例

torch.einsum求矩阵的迹

>>> import torch
>>> torch.einsum('jj', torch.randn(4, 4))
tensor(-2.5026)

torch.einsum提取矩阵对角线元素为向量

>>> import torch
>>> torch.einsum('jj->j', torch.randn(4, 4))
tensor([-0.7857, -0.1267,  0.2240,  0.7167])

torch.einsum求矩阵的外积即矩阵的乘积

>>> import torch
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('j,k->jk', [x,y])
tensor([[ 0.2342,  0.4103,  0.0309, -0.3783],
        [-0.4378, -0.7671, -0.0577,  0.7073],
        [ 0.3775,  0.6614,  0.0498, -0.6099],
        [ 0.0734,  0.1286,  0.0097, -0.1186],
        [ 0.4345,  0.7613,  0.0573, -0.7019]])

torch.einsum求批矩阵乘法

>>> import torch
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('j,k->jk', [x,y])
tensor([[ 0.2342,  0.4103,  0.0309, -0.3783],
        [-0.4378, -0.7671, -0.0577,  0.7073],
        [ 0.3775,  0.6614,  0.0498, -0.6099],
        [ 0.0734,  0.1286,  0.0097, -0.1186],
        [ 0.4345,  0.7613,  0.0573, -0.7019]])
>>> import torch
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', [As, Bs])
tensor([[[ 0.6446,  0.7576, -0.7701,  2.5654],
         [-2.1118, -0.1421, -0.1514, -1.2957]],
        [[-0.7380, -1.5484,  3.4401, -0.6755],
         [-0.7703,  2.2886, -3.7525,  0.6178]],
        [[-0.9564,  0.3753, -2.1906, -1.4844],
         [-1.0078, -0.0325, -1.2611, -1.5359]]])

torch.einsum求批矩阵转置

>>> import torch
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij-> ...ji', A).shape
torch.Size([2, 3, 5, 4])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值