Pytorch基本用法 乘法函数详解:5个张量乘法function

1、基础

张量维度:维度个数和维度大小;.ndim可查看维度个数,.shape可查看维度大小。

如下代码,张量a:维度个数为2,是一个2维张量;维度大小为[2,3],即第0维的维度大小为2,第1维为3。

>>> a=torch.arange(8).reshape(2,4)
>>> a
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])
>>> a.ndim
2
>>> a.shape
torch.Size([2, 4])
>>> a.size()
torch.Size([2, 4])

2、torch.matmul

这是最复杂也是功能最强大的乘法函数 ,可实现混合矩阵乘法。

参数:

  • input(张量):第1个张量。

  • other(张量):第2个张量。

  • out(张量):结果张量,等同于torch.matmul函数的返回值。

返回: 

  • 张量。

CASE:

(1)若2个张量皆为1维张量,即向量点积,等价于torch.dot函数。结果为scalar标量

(2)若2个张量皆为2维张量,即矩阵乘法,等价于torch.mm函数。结果为2维张量

(3)若第1个张量为1维张量,假设维度为[k],第二个张量为2维张量,假设维度为[k,p]。第一个张量会在左边进行维度扩展,维度变为[1,k],然后再进行矩阵乘法,获得维度为[1,p]的张量,然后再去掉扩展的维度,最后结果张量维度为[p]。

(4)若第1个张量为2维张量,假设维度为[k,n],第2个张量为1维张量,假设维度为[n]。第2个张量会在右边进行维度扩展,维度变为[n,1],然后再执行矩阵乘法,获得维度为[k,1]的张量,最后再去掉扩展的维度,获得维度为[k]的结果张量。 

(5)如果2个张量的维度均至少为1,且其中至少一个张量维度大于2,那么matmul将执行批矩阵乘法操作:默认使用两个张量的后两维度执行矩阵乘法,其他维度作为batch维。这个复杂,举个例子:

若两个张量均为3维张量,矩阵个数相等(第0维大小相等)且后两维满足矩阵乘法约束,那么调用matmul等价于调用torch.bmm函数

import torch
a=torch.arange(27).reshape(3,3,3)
print(a.size())
print(a.ndim)

b=torch.arange(1,10).reshape(3,3,1)
print(b.size())
print(b.ndim)

c1=torch.matmul(a,b)
c2=torch.bmm(a,b)
c1.equal(c2)

>>torch.Size([3, 3, 3])
>>3
>>torch.Size([3, 3, 1])
>>3
>>True

可能出现torch.bmm无法执行,但torch.matmul仍可执行:两个张量的后两维需满足矩阵乘法约束,不满足的情形会进行维度扩展([2]>>[2,1]),其他维则会通过广播操作([2]>>[2,1]>>[2,1,2,1])对齐。

a=torch.arange(12).reshape(2,1,3,2)
b=torch.arange(2)
c=torch.matmul(a,b)

print(a.ndim,b.ndim,c.ndim)
print(a.shape)
print(b.shape)
print(c.shape)

>>4 1 3
>>torch.Size([2, 1, 3, 2])
>>torch.Size([2])
>>torch.Size([2, 1, 3])

3、torch.dot

功能:向量点积。

参数

  • input(张量):第一个张量。

  • other(张量):第二个张量。

  • out(张量):结果张量,等同于dot函数的返回值。

返回值:张量(标量)。

重点:只支持具有相同元素个数的两个一维张量做点积操作。

4、torch.mm

功能:矩阵乘法。

参数:

input(张量):第一个矩阵,即2维张量。

mat2(张量):第二个矩阵,即2维张量。

out(张量):结果张量。

重点:torch.mm不会进行广播操作,它严格要求两个张量满足维度约束。即,假设两个张量分别为a、b,要求a.size()[1]=b.size()[0]。
4、torch.bmm

功能:批量矩阵乘法。

参数:

input(张量):第一批矩阵,即3维张量,第0维表示批大小。

mat2(张量):第二批矩阵,即3维张量,第0维表示批大小。

out(张量):结果张量,等同于torch.bmm函数返回值。也是3维张量,第0维表示批大小。

返回值:三维张量,第0维表示批大小。

重点:bmm不会进行广播操作,它严格要求两个张量均为三维张量,且第0维大小相等(表示有多少个矩阵),其他两维满足矩阵乘法约束。 即,假设两个张量分别为a、b,要求a.size()[0]=b.size()[0]且a.size()[-1]==b.size()[1]。
5、torch.mul

功能:逐元素相乘。

参数:

input(张量):第一个张量。

other(张量):第二个张量。

out(张量):结果张量,等同于mul函数的返回值。

返回值:张量。

重点:要求两个张量维度相同,即a.size()==b.size();若不同,则通过广播操作将相乘的两个张量的维度变得相同。同时,它的广播操作还会将两个张量类型统一。
6、总结

# 向量点积运算。要求输入为一维张量且类型相同、元素个数相同,输出为scaler标量。
torch.dot 
# 矩阵乘法运算,不支持广播操作。要求输入为二维张量且类型相同,维度大小满足矩阵乘法约束。
torch.mm
# 批矩阵乘法运算,不支持广播操作。要求输入为三维张量且类型相同,第0维大小相等,后两维大小满足矩阵乘法约束。
torch.bmm
# 混合矩阵乘法运算,包括向量点积、矩阵乘法、批矩阵乘法,且支持广播操作(仅针对维度)。要求输入张量类型相同,具体行为根据维度可以分五种情况。
torch.matmul
# 逐元素乘法,等价于*,支持广播操作(包括维度及类型)。无特殊要求或约束。
torch.mul

更多可参考:一文整理5个Pytorch张量乘法函数_pytorch 张量乘法_AI算法小喵的博客-CSDN博客

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sskay_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值