torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作
torch.mm(a, b)
是矩阵a和b矩阵相乘,比如a的维度是(m, x),b的维度是(x, n),返回的就是(m, n)的矩阵
相同:都可以来做矩阵相乘:
a = torch.randn(2, 3)
b = torch.randn(3, 2)
print(torch.mm(a, b))
print(torch.matmul(a, b))
区别:
matmul支持向量相乘,mm不支持。
import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x, y))
print(torch.mm(x, y)) #报错