【pytorch】矩阵乘法 mm bmm matmul mul @ * 总结

0.内积torch.dot()

在这里插入图片描述

1. 二维矩阵乘法 torch.mm()

torch.mm(mat1, mat2, out=None)
其中mat1(nxm), mat2(mxd), 输出out(nxd)
一般只用来计算两个二维矩阵的矩阵乘法,而且不支持broadcast操作。
在这里插入图片描述

torch.mm(input, mat2, out=None) → Tensor
#对矩阵imput和mat2执行矩阵乘法。 如果input为(n x m)张量,则mat2为(m x p)张量,out将为(n x p)张量。
#官方提示此功能不广播。有关广播的矩阵乘法,请参见torch.matmul()。
#example
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

2.三维带Batch矩阵乘法 torch.bmm()

torch.bmm(bmat1, bmat2, out=None)
其中bmat1(B x n x m), bmat2(B x m x d), 输出out(B x n x d)
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
在这里插入图片描述
在这里插入图片描述

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

3."混合"矩阵乘法 torch.matmul()

torch.matmul(input, other, out=None)
支持broadcast操作.
在这里插入图片描述
在这里插入图片描述

>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

4.元素乘法torch.mul()

在这里插入图片描述

>>> a = torch.randn(3)
>>> a
tensor([ 0.2015, -0.4255,  2.6087])
>>> torch.mul(a, 100)
tensor([  20.1494,  -42.5491,  260.8663])

在这里插入图片描述

>>> a = torch.randn(4, 1)
>>> a
tensor([[ 1.1207],
        [-0.3137],
        [ 0.0700],
        [ 0.8378]])
>>> b = torch.randn(1, 4)
>>> b
tensor([[ 0.5146,  0.1216, -0.5244,  2.2382]])
>>> torch.mul(a, b)
tensor([[ 0.5767,  0.1363, -0.5877,  2.5083],
        [-0.1614, -0.0382,  0.1645, -0.7021],
        [ 0.0360,  0.0085, -0.0367,  0.1567],
        [ 0.4312,  0.1019, -0.4394,  1.8753]])

6.两个乘法操作符 @ 和 *

  • @操作符可以执行矩阵乘法操作,类似 torch.mm(), torch.bmm(), torch.matmul() ;
  • *乘法操作可以执行元素乘法,使用方法类似 torch.mul()
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值