torch.mul(a, b)
是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵torch.mm(a, b)
是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵
import torch
a = torch.rand(1, 2)
b = torch.rand(1, 2)
c = torch.rand(2, 3)
print(torch.mul(a, b)) # 返回 1*2 的tensor
print(torch.mm(a, c)) # 返回 1*3 的tensor
print(torch.mul(a, c)) # 由于a、b维度不同,报错