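import torch

# 2-D case: 'ik,kj->ij' is an ordinary matrix product, so it should match matmul.
a = torch.rand(size=(5, 2)) * 1e4
b = torch.rand(size=(2, 3)) * 1e4
ab = torch.einsum('ik,kj->ij', a, b)
ab1 = torch.matmul(a, b)
print(ab.shape)
print(ab - ab1)  # element-wise difference; all zeros expected

# Batched case: same product with a leading batch dimension of size 1.
a = torch.rand(size=(1, 5, 2)) * 1e4
b = torch.rand(size=(1, 2, 3)) * 1e4
ab = torch.einsum('bik,bkj->bij', a, b)
ab1 = torch.matmul(a, b)
print(ab.shape)
print(ab - ab1)  # element-wise difference; all zeros expected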
Output of the einsum-based matrix multiplication above (torch 1.8.1):
torch.Size([5, 3])
tensor([[ 0., 0., 0.],
        [-1., 0., 0.],
        [ 0., 0., 0.],
        [ 0., 0., 0.],
        [ 0., -2., 0.]])
torch.Size([1, 5, 3])
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
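The output of the first snippet is clearly wrong: in the 2-D case einsum and matmul disagree on some entries, while the batched case matches exactly. The discrepancy grows as the matrices get larger. The same problem does not occur in TensorFlow.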
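
To judge how large the mismatch is relative to the values involved, it may help to compare both results against a float64 reference. The sketch below is only an illustration under stated assumptions: the seed, the float64 reference and the variable names are not part of the original report. It rests on the assumption (not confirmed here) that the two functions may accumulate in a different order. Note that float32 carries roughly 7 significant decimal digits, while the entries of the product can be as large as about 2e8.

```python
import torch

torch.manual_seed(0)  # hypothetical seed, only to make the run repeatable

a = torch.rand(size=(5, 2)) * 1e4
b = torch.rand(size=(2, 3)) * 1e4

ab_einsum = torch.einsum('ik,kj->ij', a, b)
ab_matmul = torch.matmul(a, b)

# Reference product computed in float64 (assumption: accurate enough to judge float32).
ref = torch.matmul(a.double(), b.double())

# Largest absolute gap between the two float32 results.
abs_diff = (ab_einsum - ab_matmul).abs().max().item()

# The same gap expressed relative to the largest entry of the reference.
rel_diff = abs_diff / ref.abs().max().item()

# Distance of each float32 result from the float64 reference.
err_einsum = (ab_einsum.double() - ref).abs().max().item()
err_matmul = (ab_matmul.double() - ref).abs().max().item()

print("max abs diff (einsum vs matmul):", abs_diff)
print("relative to largest entry:", rel_diff)
print("einsum error vs float64:", err_einsum)
print("matmul error vs float64:", err_matmul)
print("float32 machine epsilon:", torch.finfo(torch.float32).eps)
print("allclose:", torch.allclose(ab_einsum, ab_matmul))
```

If the relative gap comes out on the order of the float32 machine epsilon, the differences of 1 or 2 shown above would be consistent with rounding rather than with a logic error in either function. Repeating the run with `a.double()` and `b.double()` as inputs is one way to check whether the gap shrinks accordingly.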