【Pytorch】对比matual,mm和bmm函数

pytorch中提供了 matmulmmbmm等矩阵的乘法运算功能,但其具体计算细节和场景截然不同,应予以注意和区别。

1. torch.mm

该函数即为矩阵的乘法,torch.mm(tensor1, tenor2),参与计算的两个张量必须为二维张量(即矩阵),其结果张量out的维度关系满足: o u t ( p × q ) = t 1 ( p × m ) ∗ t 2 ( m × q ) out(p\times q)=t_1(p\times m)*t_2(m\times q) out(p×q)=t1(p×m)t2(m×q)

2. torch.bmm

该函数提供了batch维度的矩阵乘法运算,即batch内对应的矩阵两两相乘,因此要求参与计算的两个张量必须为三维张量,其中第一个维度为batch_size,必须相同,其结果张量 out的维度关系满足: o u t ( b × p × q ) = t 1 ( b × p × m ) ∗ t 2 ( b × m × q ) out(b\times p \times q)=t_1(b\times p\times m)*t_2(b\times m\times q) out(b×p×q)=t1(b×p×m)t2(b×m×q)

3. torch.matmul

该函数的功能相较于前面两个要丰富的多,其支持参与运算的两个张量有不同的维度,计算的机理也各不相同,具体包括:

(1) 两个张量均为1维张量(即向量)时,计算向量的内积

(2) 两个张量均为2维张量(即矩阵)时,计算矩阵的乘法

(3) 第一个向量为1维张量,第二个张量为2维张量,对第一个张量视情进行broadcast,然后进行矩阵乘法,在将上述结果squeeze为1维;

(4) 第二个向量为1维张量,第一个张量为2维张量,计算矩阵和向量的乘法;

(5) 两个向量维度至少为1,且其中至少一个张量的维度大于2;则先进行broadcast,然后进行bmm操作,最后将上述结果转换会broadcast之前的效果。

  • 9
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值