目录
1--torch.einsum()函数
1-1--常用用法
import torch
B = 2
C = 3
T = 4
N = 5
a = torch.randint(0, 5, [B, C, T, N], dtype = float) # 从0-5中生成维度为[B, C, T, N]的随机整数数据
b = torch.randint(0, 5, [B, C, N, N], dtype = float)
c = torch.einsum('bctn,bcmn->bctm', a, b)
print(a, a.shape)
print(b, b.shape)
print(c, c.shape)
tensor([[[[3., 3., 2., 0., 4.],
[3., 3., 0., 1., 0.],
[3., 0., 4., 3., 4.],
[3., 3., 2., 3., 2.]],
[[1., 1., 3., 0., 4.],
[0., 0., 1., 3., 0.],
[0., 4., 3., 2., 4.],
[0., 0., 3., 2., 3.]],
[[4., 3., 3., 1., 0.],
[2., 2., 0., 3., 0.],
[3., 1., 2., 3., 4.],
[2., 3., 1., 2., 2.]]],
[[[0., 3., 0., 4., 4.],
[2., 1., 1., 0., 1.],
[4., 2., 4., 3., 4.],
[3., 0., 4., 1., 0.]],
[[3., 4., 0., 1., 4.],
[2., 3., 3., 0., 1.],
[4., 2., 2., 0., 3.],
[1., 3., 4., 1., 4.]],
[[0., 2., 0., 1., 4.],
[4., 2., 3., 0., 3.],
[4., 3., 1., 0., 2.],
[4., 4., 1., 3., 4.]]]], dtype=torch.float64) torch.Size([2, 3, 4, 5])
tensor([[[[1., 3., 1., 2., 0.],
[2., 4., 0., 1., 1.],
[2., 1., 3., 0., 2.],
[0., 4., 3., 0., 0.],
[1., 4., 0., 4., 0.]],
[[4., 1., 3., 1., 4.],
[0., 3., 0., 2., 0.],
[3., 0., 1., 2., 0.],
[0., 4., 4., 2., 4.],
[1., 2., 4., 1., 2.]],
[[2., 0., 3., 0., 3.],
[4., 2., 4., 0., 3.],
[2., 4., 4., 0., 2.],
[3., 1., 2., 2., 0.],
[4., 1., 0., 4., 3.]]],
[[[0., 1., 0., 3., 4.],
[0., 2., 2., 0., 0.],
[3., 2., 3., 4., 4.],
[3., 3., 3., 4., 4.],
[0., 1., 2., 2., 2.]],
[[4., 2., 2., 4., 3.],
[3., 4., 3., 2., 2.],
[4., 1., 4., 4., 4.],
[4., 0., 1., 1., 1.],
[4., 0., 4., 1., 3.]],
[[4., 2., 0., 0., 2.],
[4., 1., 1., 2., 1.],
[1., 4., 4., 0., 1.],
[1., 0., 3., 3., 0.],
[1., 4., 4., 1., 1.]]]], dtype=torch.float64) torch.Size([2, 3, 5, 5])
tensor([[[[14., 22., 23., 18., 15.],
[14., 19., 9., 12., 19.],
[13., 13., 26., 12., 15.],
[20., 23., 19., 18., 27.]],
[[30., 3., 6., 32., 23.],
[ 6., 6., 7., 10., 7.],
[31., 16., 7., 48., 30.],
[23., 4., 7., 28., 20.]],
[[17., 34., 32., 23., 23.],
[ 4., 12., 12., 14., 22.],
[24., 34., 26., 20., 37.],
[13., 24., 24., 15., 25.]]],
[[[31., 6., 38., 41., 19.],
[ 5., 4., 15., 16., 5.],
[27., 12., 56., 58., 24.],
[ 3., 8., 25., 25., 10.]],
[[36., 35., 36., 17., 25.],
[23., 29., 27., 12., 23.],
[33., 32., 38., 21., 33.],
[34., 37., 43., 13., 33.]],
[[12., 8., 12., 3., 13.],
[26., 24., 27., 13., 27.],
[26., 22., 22., 7., 22.],
[32., 31., 28., 16., 31.]]]], dtype=torch.float64) torch.Size([2, 3, 4, 5])
1-2--运算规则
参数介绍:torch.einsum()函数有两个输入参数:
①第一个参数是equation,即代码中的字符串'bctn,bcmn->bctm',表示输入输出的维度大小;
②第二个参数是实际输入的tensor列表,即代码中的张量a和张量b。
索引介绍:torch.einsum()函数有两种索引参数:
①自由索引:出现在equation箭头右边的索引(字母),比如代码中的'bctm';
②求和索引:只出现在equation箭头左边的索引(字母),比如代码中的索引'n'。求和索引的作用是中间计算结果必须先在这个索引维度上进行求和,再输出求和后的结果。
基本规则介绍:torch.einsum()函数有三条基本规则:
①规则一:在equation箭头左侧中不同输入之间重复出现的索引,其作用是把输入张量依据该维度做乘法操作(即相应维度的对应数据相乘)。如代码的equation为'bctn,bcmn->bctm',其箭头左侧中不同输入之间重复出现的索引为:'b','c','n',表示沿着这三个维度进行乘法运算。
②规则二:仅出现在equation箭头左侧的索引,表示中间计算结果需要在这个维度上进行求和运算(即上文中介绍的求和索引)。如代码中的索引'n'只出现在箭头左侧,则沿着该维度做乘法操作后需要将结果进行求和再输出(可以理解成矩阵相乘,即对应元素相乘再相加)。
③规则三:equation箭头右侧的索引顺序可以是任意的,即进行维度交换。
易懂例子:
import torch
x = torch.randint(0, 5, [2, 3], dtype = float)
y = torch.randint(0, 5, [3, 4], dtype = float)
z1 = torch.einsum('ij,jk->ik', x, y)
z2 = torch.einsum('ij,jk->ki', x, y)
print(x, x.shape)
print(y, y.shape)
print(z1, z1.shape)
print(z2, z2.shape)
tensor([[3., 1., 4.],
[2., 3., 2.]], dtype=torch.float64) torch.Size([2, 3])
tensor([[0., 3., 1., 4.],
[4., 0., 2., 3.],
[0., 0., 2., 2.]], dtype=torch.float64) torch.Size([3, 4])
tensor([[ 4., 9., 13., 23.],
[12., 6., 12., 21.]], dtype=torch.float64) torch.Size([2, 4])
tensor([[ 4., 12.],
[ 9., 6.],
[13., 12.],
[23., 21.]], dtype=torch.float64) torch.Size([4, 2])
实例解答(实例等价于矩阵乘法):
①自由索引为'i'和'k',其顺序可以任意调换,相当于维度变换。(实例中为转置操作)
②索引'j'适用于规则一,两个元素相乘时其关于维度j的编号必须相同。如样本x中[3, 1, 4]的[3](关于维度j为编号为0),只能和样本y中[4, 3, 2]的[4]相乘(关于维度j的编号也为0)。(样本x第一列的数据元素与样本y第一行的数据元素,j编号相同,均为0)
③索引'j'为求和索引,中间计算结果需要求和。