torch.mul(a, b)
是矩阵a和b对应位相乘,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵torch.mm(a, b)
是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵
import torch
a = torch.rand(1, 2)
b = torch.rand(1, 2)
print(torch.mul(a, b)) # 返回 1*2 的tensor
print(torch.mm(a, c)) # 返回 1*3 的tensor