引言
在线性代数里,经常会遇到各种计算操作符号,比如矩阵的点积,外积, H H H积, K K K积,转置等。爱因斯坦求和约定提供了一套简单优雅的规则可以实现以上操作,目的是省略掉求和公式中的求和号“+”。
定义(爱因斯坦求和约定):如果两个相同的指标出现在指标符号公式的同一项中,则表示对该指标遍历整个取值范围求和。
爱因斯坦求和约定具体的规则可以归结为如下几条
- 在同一项中,如果同一指标成对出现,就表示遍历其取值范围求和
- 公式中成对出现的指标叫做哑指标亦或哑标,表示哑标的小写字母可以用另一个小写字母替换,并且其取值范围不变
- 相同的指标求和指标,而其余的指标称为自由指标
- 当两个求和式相乘时, 两个求和式的哑标不能使用相同的字母
当前爱因斯坦求和约定已经在numpy,keras,tensorflow和pytorch等库中进行了实现。本文主要介绍爱因斯坦求和约定在pytorch中的实现。
torch.einsum用法介绍
pytorch中实现爱因斯坦求和约定的函数是torch.einsum,该函数的功能和用法细节如下所示
torch.einsum(equation, *operands) ⟶ \longrightarrow ⟶ Tensor
- equation (string):表示爱因斯坦求和下标公式
- operands (List[Tensor]):计算爱因斯坦求和的张量列表
注意事项:
- torch.einsum函数的第一个参数表示输入和输出张量的维度,其中equation中的箭头左边表示输入张量,逗号分割开每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是 26 26 26个大小写英文字母’a’—‘z’或’A’—‘Z’
- torch.einsum函数的第二个参数表示实际输入的张量列表,其数量要与第一个参数equation的输入数量对应,并且其字符数量要与张量的真实维度对应,例如 i j , j k → i k \mathrm{ij,jk}\rightarrow \mathrm{ik} ij,jk→ik表示输入和输出张量都是 2 2 2维的
- equation可以不写包括箭头在内的右边部分,则此时输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,例如的矩阵乘法 i j , j k → i k \mathrm{ij,jk}\rightarrow \mathrm{ik} ij,jk→ik也可以简化为 i j , j k \mathrm{ij,jk} ij,jk,根据默认规则,输出就是 i k \mathrm{ik} ik
- equation中支持“ . . . ... ...”省略号,用于表示不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写成“ . . . i j → . . . j i \mathrm{...ij \rightarrow...ji} ...ij→...ji”
代码实例
torch.einsum求矩阵的迹
>>> import torch
>>> torch.einsum('jj', torch.randn(4, 4))
tensor(-2.5026)
torch.einsum提取矩阵对角线元素为向量
>>> import torch
>>> torch.einsum('jj->j', torch.randn(4, 4))
tensor([-0.7857, -0.1267, 0.2240, 0.7167])
torch.einsum求矩阵的外积即矩阵的乘积
>>> import torch
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('j,k->jk', [x,y])
tensor([[ 0.2342, 0.4103, 0.0309, -0.3783],
[-0.4378, -0.7671, -0.0577, 0.7073],
[ 0.3775, 0.6614, 0.0498, -0.6099],
[ 0.0734, 0.1286, 0.0097, -0.1186],
[ 0.4345, 0.7613, 0.0573, -0.7019]])
torch.einsum求批矩阵乘法
>>> import torch
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('j,k->jk', [x,y])
tensor([[ 0.2342, 0.4103, 0.0309, -0.3783],
[-0.4378, -0.7671, -0.0577, 0.7073],
[ 0.3775, 0.6614, 0.0498, -0.6099],
[ 0.0734, 0.1286, 0.0097, -0.1186],
[ 0.4345, 0.7613, 0.0573, -0.7019]])
>>> import torch
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', [As, Bs])
tensor([[[ 0.6446, 0.7576, -0.7701, 2.5654],
[-2.1118, -0.1421, -0.1514, -1.2957]],
[[-0.7380, -1.5484, 3.4401, -0.6755],
[-0.7703, 2.2886, -3.7525, 0.6178]],
[[-0.9564, 0.3753, -2.1906, -1.4844],
[-1.0078, -0.0325, -1.2611, -1.5359]]])
torch.einsum求批矩阵转置
>>> import torch
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij-> ...ji', A).shape
torch.Size([2, 3, 5, 4])