一、torch.mul
该乘法可简单理解为矩阵各位相乘,一个常见的例子为向量点乘,源码定义为torch.mul(input,other,out=None)
。其中other可以为一个数也可以为一个张量,other为数即张量的数乘。
该函数可触发广播机制(broadcast)。
tensor1 = 2*torch.ones(1,4)
tensor2 = 3*torch.ones(4,1)
print(torch.mul(tensor1, tensor2))
#输出结果为:
tensor([[6., 6., 6., 6.],
[6., 6., 6., 6.],
[6., 6., 6., 6.],
[6., 6., 6., 6.]])
二、torch.mm
这是我们在线性代数课程中学习的矩阵乘法。该函数源码定义为torch.mm(input,mat2,out=None)
,参数与返回值均为tensor形式。
a=torch.ones(4,3)
b=2*torch.ones(3,2)
c=torch.empty(4,2)
torch.mm(a,b,out=c)
print(torch.mm(a,b))
print( c )
#输出结果为
tensor([[6., 6.],
[6., 6.],
[6., 6.],
[6., 6.]])
tensor([[6., 6.],
[6., 6.],
[6., 6.],
[6., 6.]])