torch.einsum
>>> a = torch.arange(60.).reshape(5,3,4)
>>> b = torch.arange(24.).reshape(3,4,2)
>>> o = torch.einsum('fnd,ndh->fh', a, b)
>>> o
tensor([[1012., 1078.],
        [2596., 2806.],
        [4180., 4534.],
        [5764., 6262.],
        [7348., 7990.]])
>>> torch.matmul(a[0,:,:].flatten(), b[:,:,0].flatten())
tensor(1012.)
# this equals o[0, 0], the first element of the einsum result
>>> torch.matmul(a[1,:,:].flatten(), b[:,:,0].flatten())
tensor(2596.)