关于pytorch中部分矩阵乘法的总结(torch.mm,torch.mul,torch.matmul)

一、torch.mul

该乘法可简单理解为矩阵各位相乘,一个常见的例子为向量点乘,源码定义为torch.mul(input,other,out=None)。其中other可以为一个数也可以为一个张量,other为数即张量的数乘。
该函数可触发广播机制(broadcast)。

tensor1 = 2*torch.ones(1,4)
tensor2 = 3*torch.ones(4,1)
print(torch.mul(tensor1, tensor2))
#输出结果为:
tensor([[6., 6., 6., 6.],
        [6., 6., 6., 6.],
        [6., 6., 6., 6.],
        [6., 6., 6., 6.]])

二、torch.mm

这是我们在线性代数课程中学习的矩阵乘法。该函数源码定义为torch.mm(input,mat2,out=None) ,参数与返回值均为tensor形式。

a=torch.ones(4,3)  
b=2*torch.ones(3,2)  
c=torch.empty(4,2)  
torch.mm(a,b,out=c)  
print(torch.mm(a,b))  
print( c )

#输出结果为
tensor([[6., 6.],
        [6., 6.],
        [6., 6.],
        [6., 6.]])
tensor([[6., 6.],
        [6., 6.],
        [6., 6.],
        [6., 6.]])

三、torch.matmul

这个矩阵乘法是在torch.mm的基础上增加了广播机制,源码定义为torch.matmul(input,other,out=None)
其基本运算规则如下:

  • 如果两个参数都为一维,则等价于torch.mul,需要注意的是:此时的out不接受任何参数
  • 如果两个张量都为二维且符合矩阵相乘规则,或第一个参数为一维(长度为m,这里等价为大小为1* m),第二个参数为二维(大小为m* n)则运算等价于torch.mm
  • 如果第一个参数为二维(大小m* n),第二个参数为一维(长度为n),这里第二个参数会进行转置成为n* 1的列向量,随后进行矩阵相乘,将得到的结果再进行转置,最终返回一个大小为1* m的向量
tensor1 = torch.tensor([[1,1,1,1],[2,2,2,2],[3,3,3,3]],dtype=torch.float32)
tensor2 = torch.ones(4)
print(tensor1.size())
print(tensor2.size())
print(torch.matmul(tensor1, tensor2))
#输出结果为:
torch.Size([3, 4])
torch.Size([4])
tensor([ 4.,  8., 12.])
  • 最后的情况就是任意一个参数至少为3维, 当前面的维度相同且最后两个维度符合二维矩阵运算规则可进行计算,例如第一参数的大小为a* b * c * m,第二个参数的大小为a* b* m* d,则返回一个大小为a* b* c * d的张量,可触发广播机制。
tensor1 = torch.ones(1,4,3,2)
tensor2 = torch.ones(2,6)
print(torch.matmul(tensor1, tensor2).size())
#输出结果为:
torch.Size([1, 4, 3, 6])

参考文献:https://pytorch.org/docs/stable/torch.html?highlight=matmul#torch.matmul

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值