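The einsum(equation, *tensors) function applies the operation described by equation to one or more input tensors (*tensors).
The equation labels each dimension of the input tensors' shapes with a letter. It resembles rearrange, reduce, and repeat from the einops module, but follows different rules:
1. An equation typically looks like 'i j -> j i' (matrix transpose; i and j label the corresponding dimensions of the tensor's shape). # e.g. for shape (2, 3), i=2 and j=3.
2. If the arrow '->' is omitted altogether, the output consists of the labels that appear exactly once on the left, in alphabetical order, so 'ij, jk' is equivalent to 'ij, jk -> ik' (matrix multiplication). If the arrow is present but nothing follows it, as in 'ij, jk -> ', every dimension is summed out and the result is a scalar.
3. A label that appears in more than one input on the left (each occurrence pointing to a dimension of that tensor's shape) means the inputs are multiplied along that dimension; if the label is absent from the right-hand side, the products are also summed over it. For example, 'ij, jk -> ik' multiplies along 'j' and sums over it (matrix multiplication).
4. Building on rule 3, the labels on the right of the arrow can be reordered freely to reorder the output dimensions, e.g. 'ij, jk -> ki' (matrix multiplication followed by transposing ik to ki).
5. To operate on several tensors, separate their subscripts with commas on the left of the arrow, one group per tensor.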
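The multiply-and-sum behaviour can be spelled out with plain loops. The sketch below only illustrates what 'ij, jk -> ik' and 'ij, jk -> ' compute; it is not how torch implements them. The helper names matmul_like and sum_all are hypothetical, and the inputs are ordinary nested lists.

```python
# Hypothetical helpers that mimic two einsum equations with explicit loops.

def matmul_like(a, b):
    """Mimics 'ij, jk -> ik': j is shared and absent from the output, so it is summed."""
    I, J, K = len(a), len(b), len(b[0])
    out = [[0] * K for _ in range(I)]
    for i in range(I):
        for k in range(K):
            for j in range(J):
                out[i][k] += a[i][j] * b[j][k]
    return out


def sum_all(a, b):
    """Mimics 'ij, jk -> ': no label survives, so i, j and k are all summed."""
    total = 0
    for i in range(len(a)):
        for j in range(len(b)):
            for k in range(len(b[0])):
                total += a[i][j] * b[j][k]
    return total


a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(matmul_like(a, b))  # [[58, 64], [139, 154]]
print(sum_all(a, b))      # 415
```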
import torch
x = [
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
],
[
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]
]
]
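# Convert x to a float tensor
x = torch.tensor(x).float()
print(x.shape) # torch.Size([2, 3, 4])
# einsum() reduces the tensor: 'b' is missing from the output, so it is summed out
x1_1 = torch.einsum('b h w -> h w', x)
# Equivalent to einops reduce(x, 'b h w -> h w', 'sum'): einsum sums over dropped dimensions, it does not take the max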
print(x1_1)
print(x1_1.shape)
"""
tensor([[14., 16., 18., 20.],
[22., 24., 26., 28.],
[30., 32., 34., 36.]])
torch.Size([3, 4])
"""
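To confirm that a dropped label is summed rather than max-pooled, the result can be compared with plain torch reductions. This is a minimal sketch: it assumes the same 2x3x4 values as x, rebuilt with torch.arange so the block runs on its own, and the variable names are illustrative.

```python
import torch

# Same values as x above: 1..24 arranged as (b=2, h=3, w=4).
t = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 4)

summed = torch.einsum('bhw->hw', t)

# Dropping 'b' sums over it, which matches Tensor.sum along dim 0.
print(torch.equal(summed, t.sum(dim=0)))   # True
# A max over the same dimension gives a different result.
print(torch.equal(summed, t.amax(dim=0)))  # False
```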
# einsum() permutes the dimensions of the tensor (any number of dimensions can be reordered in one call)
x1_2 = torch.einsum('i j k -> k i j', x)
print(x1_2)
print(x1_2.shape)
"""
tensor([[[ 1., 5., 9.],
[13., 17., 21.]],
[[ 2., 6., 10.],
[14., 18., 22.]],
[[ 3., 7., 11.],
[15., 19., 23.]],
[[ 4., 8., 12.],
[16., 20., 24.]]])
torch.Size([4, 2, 3])
"""
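# torch.einsum() cannot split or merge dimensions with arithmetic the way einops patterns such as '(h w)' can,
# because every subscript in the equation must be a single letter from [a-zA-Z]; the only other token allowed is the ellipsis '...', which stands for all unnamed dimensions
x1_3 = torch.einsum('...h w -> ...w h', x) # swap the last two dimensions of x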
print(x1_3)
print(x1_3.shape)
"""
tensor([[[ 1., 5., 9.],
[ 2., 6., 10.],
[ 3., 7., 11.],
[ 4., 8., 12.]],
[[13., 17., 21.],
[14., 18., 22.],
[15., 19., 23.],
[16., 20., 24.]]])
torch.Size([2, 4, 3])
"""
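Because '...' absorbs however many leading dimensions there are, the same equation should work for inputs of different rank. The sketch below checks this against Tensor.transpose; swap_last_two is a hypothetical helper name and the shapes are arbitrary.

```python
import torch


def swap_last_two(t):
    # '...' stands for every leading dimension, including none at all.
    return torch.einsum('...hw->...wh', t)


for shape in [(3, 4), (2, 3, 4), (5, 2, 3, 4)]:
    t = torch.randn(shape)
    out = swap_last_two(t)
    # Same values as swapping the last two axes with transpose.
    assert torch.equal(out, t.transpose(-1, -2))
    print(tuple(t.shape), '->', tuple(out.shape))
```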
# torch.einsum() also supports matrix multiplication (multiply along the shared index, then sum over it), the matrix trace, and similar operations
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.shape) # torch.Size([2, 3])
b = torch.tensor([[7, 8], [9, 10], [11, 12]])
print(b.shape) # torch.Size([3, 2])
x2_1 = torch.einsum('ij, jk -> ik', a, b) # matrix multiplication
print(x2_1)
print(x2_1.shape)
"""
tensor([[ 58, 64],
[139, 154]])
torch.Size([2, 2])
"""
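The role of the arrow can be checked on the same two matrices. With no arrow at all, torch.einsum infers the output from the labels that occur exactly once, in alphabetical order; with an arrow and an empty right-hand side, everything is summed. A minimal sketch that rebuilds a and b so it runs on its own:

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8], [9, 10], [11, 12]])

explicit = torch.einsum('ij,jk->ik', a, b)
implicit = torch.einsum('ij,jk', a, b)   # no arrow: output is inferred as 'ik'
scalar = torch.einsum('ij,jk->', a, b)   # empty output: sum over i, j and k

print(torch.equal(explicit, a @ b))      # True
print(torch.equal(implicit, explicit))   # True
print(scalar)                            # tensor(415)
```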
c = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
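x2_2 = torch.einsum('ii -> ', c) # matrix trace
print(x2_2) # tensor(15), a single number
print(x2_2.shape) # torch.Size([]): a 0-dimensional (scalar) tensor has an empty shape
print(x2_2.shape is None) # False: the shape is an empty torch.Size, not None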
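The empty shape is simply how a 0-dimensional tensor reports itself. The sketch below inspects such a tensor in a few ways and compares the value with torch.trace; the variable name tr is illustrative.

```python
import torch

c = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
tr = torch.einsum('ii->', c)

print(tr.dim())                            # 0: the tensor has no dimensions
print(tr.shape == torch.Size([]))          # True: the shape is an empty, tuple-like object
print(len(tr.shape))                       # 0
print(tr.item())                           # 15, as a plain Python int
print(tr.item() == torch.trace(c).item())  # True
```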
# torch.einsum() computes the dot product of two vectors
d = torch.tensor([1, 2, 3])
e = torch.tensor([4, 5, 6])
x3_1 = torch.einsum('i, i -> ', d, e) # vector dot product
print(x3_1) # tensor(32)
print(x3_1.shape) # torch.Size([])
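Whether the shared label appears on the right-hand side decides if the products are summed. The sketch below contrasts 'i,i->i' with 'i,i->' on the same vectors and compares the latter with torch.dot; the variable names are illustrative.

```python
import torch

d = torch.tensor([1, 2, 3])
e = torch.tensor([4, 5, 6])

elementwise = torch.einsum('i,i->i', d, e)  # 'i' kept: multiply only
dot = torch.einsum('i,i->', d, e)           # 'i' dropped: multiply, then sum

print(elementwise)                           # tensor([ 4, 10, 18])
print(dot)                                   # tensor(32)
print(dot.item() == torch.dot(d, e).item())  # True
```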