a = torch.tensor([[[1,2,3,4],[5,6,7,8]]])
a.shape == torch.Size([1,2,4])
b = torch.tensor([[[1,2,3,5],[6,1,2,3],[1,3,2,1]]])
b.shape == torch.Size([1,3,4])
c = torch.einsum('ijk,ilk->ijl', a,b)
print(c)
tensor([[[34, 26, 17],
[78, 74, 45]]])
根据上述例子粗暴的解释torch.einsum('ijk,ilk->ijl', a,b):
对于两个shape为 [1,2,4] ,[1,3,4]的tensor
首先给需要计算的a和b的每个维度作标记,比如输入的a (1=i,2=j,4=k)和b (1=i,3=l,4=k), 输出c(1=i,2=j,3=l)
可以看到ab第0维i相同,所以在第1维第2维做运算。
看成没有第0维的二维矩阵乘法(a乘b的转置), 维度变化为:(2,4) *(4,3)-> (2,3)
可验证:a的第一行乘以 b的转置的第一列=34 ,[1,2,3,4] *[1,2,3,5] = 1+2*2+3*3+4*5= 34