Einstein Summation Convention
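The Einstein summation convention (einsum) provides a concise and elegant set of rules for expressing tensor operations, including but not limited to vector inner products, vector outer products, matrix multiplication, transposition, and tensor contraction. Once you are comfortable with einsum, complex tensor operations become easy to write and harder to get wrong.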
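Free indices and summation indices:
- Free indices: indices that appear on the right side of the arrow
- Summation indices: indices that appear only on the left side of the arrow; the intermediate result is summed over these dimensions to produce the output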
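To make the distinction concrete, the sketch below classifies the indices of an equation string. The helper name `classify_indices` is hypothetical (it is not part of PyTorch or NumPy), and it assumes an explicit `->` and plain letter indices without `...`.

```python
def classify_indices(equation):
    """Split the indices of an explicit einsum equation into free and summation indices.

    Hypothetical helper for illustration only; supports letter indices
    and requires an explicit '->' in the equation.
    """
    inputs, output = equation.replace(" ", "").split("->")
    input_indices = set(inputs.replace(",", ""))
    free = sorted(set(output))                    # indices right of the arrow
    summed = sorted(input_indices - set(output))  # indices only left of the arrow
    return free, summed


print(classify_indices("ik,kj->ij"))  # (['i', 'j'], ['k'])
print(classify_indices("ij->"))       # ([], ['i', 'j'])
```

In `ik,kj->ij` the index `k` is summed away, while `i` and `j` survive as the axes of the output.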
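Three basic rules
- Rule 1: an index that is repeated across different inputs on the left side of the arrow in the equation means the input tensors are multiplied along that dimension;
- Rule 2: an index that appears only on the left side of the arrow means the intermediate result is summed over that dimension;
- Rule 3: the indices on the right side of the arrow may appear in any order.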
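The three rules can be spelled out as plain loops. The following is a minimal sketch for the equation `ik,kj->ij`, written with NumPy's `np.einsum` rather than `torch.einsum` on the assumption that both follow the same subscript rules for this case; the array values are chosen only for illustration.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)    # indices i, k
b = np.arange(15).reshape(3, 5)   # indices k, j

# Explicit loops for 'ik,kj->ij'
out = np.zeros((2, 5), dtype=a.dtype)
for i in range(2):          # free index
    for j in range(5):      # free index
        for k in range(3):  # summation index: shared by both inputs, absent from the output
            out[i, j] += a[i, k] * b[k, j]  # rule 1 (multiply along k) and rule 2 (sum over k)

assert np.array_equal(out, np.einsum('ik,kj->ij', a, b))

# Rule 3: writing the output indices in another order permutes the result's axes
assert np.array_equal(out.T, np.einsum('ik,kj->ji', a, b))
```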
Usage
Extracting elements
Extracting the diagonal elements of a matrix
import torch
a = torch.arange(9).reshape(3, 3)
x = torch.einsum('ii->i', [a])
print(a)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
print(x)
tensor([0, 4, 8])
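As a cross-check, the same subscript pattern can be compared against dedicated functions. This sketch uses NumPy (an assumption; the example above uses PyTorch) and also shows that dropping the output index turns the diagonal extraction into a trace.

```python
import numpy as np

a = np.arange(9).reshape(3, 3)

diag = np.einsum('ii->i', a)   # keep i: the diagonal elements
trace = np.einsum('ii->', a)   # drop i: the diagonal is summed (the trace)

assert np.array_equal(diag, np.diagonal(a))
assert trace == np.trace(a)
print(diag, trace)  # [0 4 8] 12
```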
Transpose
Matrix transpose
import torch
a = torch.arange(6).reshape(2, 3)
x = torch.einsum('ij->ji', [a])
print(a)
tensor([[0, 1, 2],
[3, 4, 5]])
print(x)
tensor([[0, 3],
[1, 4],
[2, 5]])
permute: transposing the last two dimensions of a higher-dimensional tensor
import torch
a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
x = torch.einsum('...ij->...ji', [a])
print(a.shape)
torch.Size([2, 3, 5, 7, 9])
print(x.shape)
torch.Size([2, 3, 5, 9, 7])
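The ellipsis form can be checked against ordinary axis-permutation functions. This is a sketch in NumPy (assumed to treat `...` the same way for this case); the second equation, which reverses all five axes, is an extra illustration that is not in the example above.

```python
import numpy as np

a = np.random.randn(2, 3, 5, 7, 9)

x = np.einsum('...ij->...ji', a)   # swap only the last two axes
y = np.swapaxes(a, -1, -2)         # the same operation with a dedicated function
z = np.einsum('abcde->edcba', a)   # name every axis to permute them freely

assert x.shape == (2, 3, 5, 9, 7)
assert np.array_equal(x, y)
assert np.array_equal(z, a.transpose())  # transpose() without arguments reverses all axes
```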
Summation
Sum of all matrix elements
import torch
a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij->', [a])
print(a)
tensor([[0, 1, 2],
[3, 4, 5]])
print(x)
tensor(15)
Sum over each row of a matrix
import torch
a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij->i', [a])
print(a)
tensor([[0, 1, 2],
[3, 4, 5]])
print(x)
tensor([ 3, 12])
Sum over each column of a matrix
import torch
a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij->j', [a])
print(a)
tensor([[0, 1, 2],
[3, 4, 5]])
print(x)
tensor([3, 5, 7])
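The three summation patterns differ only in which index is kept on the right side of the arrow. The sketch below, in NumPy, compares each one with the equivalent `sum` call; the mapping to `axis` values is the point of the example.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

total = np.einsum('ij->', a)      # no index kept: sum everything
row_sums = np.einsum('ij->i', a)  # keep i: sum over j within each row
col_sums = np.einsum('ij->j', a)  # keep j: sum over i within each column

assert total == a.sum()
assert np.array_equal(row_sums, a.sum(axis=1))
assert np.array_equal(col_sums, a.sum(axis=0))
```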
Products
Matrix-vector multiplication
import torch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
# i = 2, k = 3
x = torch.einsum('ik,k->i', [a, b])
# Equivalent form: the arrow and the output can be omitted
x2 = torch.einsum('ik,k', [a, b])
print(a)
tensor([[0, 1, 2],
[3, 4, 5]])
print(b)
tensor([0, 1, 2])
print(x)
tensor([ 5, 14])
print(x2)
tensor([ 5, 14])
Matrix multiplication
import torch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
# i = 2, k = 3, j = 5
x = torch.einsum('ik,kj->ij', [a, b])
# Equivalent form: the arrow and the output can be omitted
x2 = torch.einsum('ik,kj', [a, b])
print(a)
tensor([[0, 1, 2],
[3, 4, 5]])
print(b)
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
print(x)
tensor([[ 25, 28, 31, 34, 37],
[ 70, 82, 94, 106, 118]])
print(x2)
tensor([[ 25, 28, 31, 34, 37],
[ 70, 82, 94, 106, 118]])
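Both multiplications above correspond to the `@` operator. A short NumPy sketch of that equivalence, using the same values as the examples (the variable name `v` is introduced here for the vector):

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
v = np.arange(3)
b = np.arange(15).reshape(3, 5)

mv = np.einsum('ik,k->i', a, v)     # matrix-vector product
mm = np.einsum('ik,kj->ij', a, b)   # matrix-matrix product

assert np.array_equal(mv, a @ v)
assert np.array_equal(mm, a @ b)
```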
Vector inner product
import torch
a = torch.arange(3)
b = torch.arange(3, 6) # [3, 4, 5]
# i = 3
x = torch.einsum('i,i->', [a, b])
# Equivalent form: the arrow and the output can be omitted
x2 = torch.einsum('i,i', [a, b])
print(a)
tensor([0, 1, 2])
print(b)
tensor([3, 4, 5])
print(x)
tensor(14)
print(x2)
tensor(14)
Element-wise matrix product followed by a sum
import torch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij,ij->', [a, b])
# Equivalent form: the arrow and the output can be omitted
x2 = torch.einsum('ij,ij', [a, b])
print(a)
tensor([[0, 1, 2],
[3, 4, 5]])
print(b)
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
print(x)
tensor(145)
print(x2)
tensor(145)
Vector outer product
import torch
a = torch.arange(3)
b = torch.arange(3,7) # [3, 4, 5, 6]
# i = 3, j = 4
x = torch.einsum('i,j->ij', [a, b])
# Equivalent form: the arrow and the output can be omitted
x2 = torch.einsum('i,j', [a, b])
print(a)
tensor([0, 1, 2])
print(b)
tensor([3, 4, 5, 6])
print(x)
tensor([[ 0, 0, 0, 0],
[ 3, 4, 5, 6],
[ 6, 8, 10, 12]])
print(x2)
tensor([[ 0, 0, 0, 0],
[ 3, 4, 5, 6],
[ 6, 8, 10, 12]])
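The inner product, the summed element-wise product, and the outer product differ only in which indices are shared and which are kept. The NumPy sketch below compares each with a dedicated function; the last line, which keeps both indices to get the element-wise product without summing, is an extra illustration not shown above.

```python
import numpy as np

a = np.arange(3)
b = np.arange(3, 6)
c = np.arange(3, 7)
m = np.arange(6).reshape(2, 3)
n = np.arange(6, 12).reshape(2, 3)

inner = np.einsum('i,i->', a, b)          # shared index, nothing kept
summed = np.einsum('ij,ij->', m, n)       # element-wise product, then sum
outer = np.einsum('i,j->ij', a, c)        # no shared index, both kept
hadamard = np.einsum('ij,ij->ij', m, n)   # shared indices kept: no summation

assert inner == np.dot(a, b)
assert summed == (m * n).sum()
assert np.array_equal(outer, np.outer(a, c))
assert np.array_equal(hadamard, m * n)
```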
Batch matrix multiplication
import torch
a = torch.randn(2,3,5)
b = torch.randn(2,5,4)
# i = 2, j = 3, k = 5, l = 4
x = torch.einsum('ijk,ikl->ijl', [a, b])
print(a.shape)
torch.Size([2, 3, 5])
print(b.shape)
torch.Size([2, 5, 4])
print(x.shape)
torch.Size([2, 3, 4])
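For a numerical check rather than only a shape check, the batched equation can be compared with `np.matmul`, which broadcasts over leading batch axes. This is a NumPy sketch; the ellipsis variant is an assumption-free rewrite of the same equation that works for any number of batch axes.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 5))
b = rng.standard_normal((2, 5, 4))

x = np.einsum('ijk,ikl->ijl', a, b)         # i is the batch index, k is summed
y = np.matmul(a, b)                         # batched matrix product
z = np.einsum('...jk,...kl->...jl', a, b)   # same, without naming the batch axis

assert x.shape == (2, 3, 4)
assert np.allclose(x, y)
assert np.allclose(x, z)
```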
Transformations
Tensor contraction
import torch
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
# p = 2, q = 3, r = 5, s = 7
# t = 11, u = 13, v = 17, r = 5
x = torch.einsum('pqrs,tuqvr->pstuv', [a, b])
print(a.shape)
torch.Size([2, 3, 5, 7])
print(b.shape)
torch.Size([11, 13, 3, 17, 5])
print(x.shape)
torch.Size([2, 7, 11, 13, 17])
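The contraction above sums over the two shared indices `q` and `r`. A NumPy sketch of the same computation with `np.tensordot`, where the contracted axes are given by position instead of by letter:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 5, 7))          # p q r s
b = rng.standard_normal((11, 13, 3, 17, 5))    # t u q v r

x = np.einsum('pqrs,tuqvr->pstuv', a, b)

# Contract a's axes (1, 2) = (q, r) with b's axes (2, 4) = (q, r).
# tensordot keeps the remaining axes of a (p, s), then those of b (t, u, v).
y = np.tensordot(a, b, axes=([1, 2], [2, 4]))

assert x.shape == (2, 7, 11, 13, 17)
assert np.allclose(x, y)
```

With einsum the axis bookkeeping is carried by the letters, which is easier to read than positional axis lists once several axes are involved.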
Bilinear transformation
import torch
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
# i = 2, k = 3, j = 5, l = 7
x = torch.einsum('ik,jkl,il->ij', [a, b, c])
print(a.shape)
torch.Size([2, 3])
print(b.shape)
torch.Size([5, 3, 7])
print(c.shape)
torch.Size([2, 7])
print(x.shape)
torch.Size([2, 5])
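The three-operand equation computes x[i, j] = sum over k and l of a[i, k] * b[j, k, l] * c[i, l]. The NumPy sketch below checks this in two ways: by splitting the contraction into two stages, and by an explicit per-element product. The intermediate name `tmp` is introduced only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3))      # i k
b = rng.standard_normal((5, 3, 7))   # j k l
c = rng.standard_normal((2, 7))      # i l

x = np.einsum('ik,jkl,il->ij', a, b, c)

# Two-stage version: contract k first, then l
tmp = np.einsum('ik,jkl->ijl', a, b)
y = np.einsum('ijl,il->ij', tmp, c)

# Per-element version: x[i, j] = a[i]^T  b[j]  c[i]
z = np.array([[a[i] @ b[j] @ c[i] for j in range(5)] for i in range(2)])

assert x.shape == (2, 5)
assert np.allclose(x, y)
assert np.allclose(x, z)
```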
Other operations
- Vector operations: A and B are both vectors
- Matrix operations: A and B are both 2D matrices
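Notes:
- einsum does not promote the data type when summing, so a data type with a limited range can silently overflow and give a wrong result:
import numpy as np
a = np.ones(300, dtype=np.int8)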
print(np.sum(a)) # correct result
print(np.einsum('i->', a)) # produces incorrect result
300
44
- In implicit mode, einsum may not order the output axes the way you expect:
M = np.arange(24).reshape(2,3,4)
print(np.einsum('kij', M).shape) # not what you might expect
print(np.einsum('ijk->kij', M).shape) # as expected
(3, 4, 2)
(4, 2, 3)
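np.einsum('kij', M) is actually equivalent to np.einsum('kij->ijk', M): in implicit mode, einsum takes the labels that appear in the input, sorts them alphabetically, and uses the result as the output labels (labels that are repeated are summed over and left out).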
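Both pitfalls have simple workarounds. The sketch below shows one possible approach for each, not the only one: cast to a wider integer type before summing, and write the arrow explicitly whenever the axis order matters.

```python
import numpy as np

# Pitfall 1: int8 cannot hold 300, so widen the input before summing
a = np.ones(300, dtype=np.int8)
safe_sum = np.einsum('i->', a.astype(np.int64))
assert safe_sum == 300

# Pitfall 2: implicit mode sorts the output labels alphabetically,
# so 'kij' behaves like 'kij->ijk'
M = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(np.einsum('kij', M), np.einsum('kij->ijk', M))

# Writing the arrow states the intended axis order explicitly
assert np.einsum('kij->kij', M).shape == (2, 3, 4)
```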
References:
https://zhuanlan.zhihu.com/p/101157166
https://zhuanlan.zhihu.com/p/361209187