torch.mul()、torch.mm()和torch.matmul()等函数使用法则

引 言  torch中的tensor张量之间相乘操作分为矩阵乘法和元素级乘法,两种乘法运算方式对于初学者而言很容易混淆。结合本人在实践操作中的经验,将pytorch中常用torch.mul()、torch.mm()和torch.matmul()等函数的用法进行详细介绍和举例说明。

目录

一、torch.mul()矩阵元素级乘法函数

二、torch.mm()二维矩阵乘法函数

三、torch.matmul()矩阵乘法函数

3.1 1维向量×1维向量的内积运算

3.2 1维向量×2维矩阵或2维矩阵×1维向量

3.3 2维矩阵×2维矩阵

3.4 三维矩阵相乘

四、torch.mv()矩阵向量乘法函数

五、@运算符的矩阵乘法

六、总结与注意


一、torch.mul()矩阵元素级乘法函数

torch.mul()函数主要对矩阵中的元素实施Hadamard积运算,该运算属于元素级相乘操作。可以直接使用“ * ”替换torch.mul()函数。在矩阵运算中,要求两个矩阵的维度相同,矩阵A\epsilon R^{m\times n},B\epsilon R^{m\times n},矩阵A和B的Hadamard积为:

 A\odot B = \begin{bmatrix} a_{11}b_{11}& a_{12} b_{12}&\cdots &a_{1n} b_{1n}& \\ a_{21}b_{21}& a_{22} b_{22}&\cdots &a_{2n} b_{2n}& \\ \vdots &\vdots & \ddots &\vdots \\ a_{m1}b_{m1}& a_{m2} b_{m2}&\cdots &a_{mn} b_{mn}& \end{bmatrix}

矩阵元素级乘法也可以用于向量×矩阵的情况,此时要求向量的长度与矩阵最后一个维度相同,采用广播机制变成与矩阵相同的形状,随后进行逐元素相乘操作。

x = torch.tensor([[1,1],[3,3],[4,4]])
y = torch.tensor([2,2])
out1 = torch.mul(x,y)    #等价于out1 = x*y


#结果
tensor([[2, 2],
        [6, 6],
        [8, 8]])

二、torch.mm()二维矩阵乘法函数

torch.mm()只适合于二维矩阵乘法运算,如果矩阵维度超过两个维度则会报错。二维矩阵乘法运算要求第一个矩阵的列数与第二个矩阵的行数相同

import torch

A = torch.randint(1,5,size=(2,3))
B = torch.randint(1,5,(3,2))
print('A: \n',A)
print('B: \n',B)
result = torch.mm(A,B)
print('result: \n {}'.format(result))

##结果##
A: 
 tensor([[2, 3, 2],
        [1, 4, 4]])
B: 
 tensor([[2, 2],
        [4, 4],
        [2, 3]])
result: 
 tensor([[20, 22],
        [26, 30]])

三、torch.matmul()矩阵乘法函数

torch.matmul()属于广义矩阵乘法函数操作,适用形式有:1维向量×1维向量,1维向量×2维矩阵,2维矩阵×1维向量,任意维度矩阵相乘等。每种情况的具体使用会结合示例代码逐一介绍。

3.1 1维向量×1维向量的内积运算

torch.matmul()函数作用于两个1维向量运算时,两个向量长度相同,主要对两个1维向量进行内积运算(结果为标量)。功能与torch.dot()函数相同(torch.dot()函数只能用于1维向量运算)。

x = torch.tensor([2,3,4])
y = torch.tensor([2,2,2])
out1 = torch.matmul(x,y)  #out1 : tensor(18)
out2 = torch.dot(x,y)     #out2 : tensor(18)

3.2 1维向量×2维矩阵或2维矩阵×1维向量

向量与矩阵做矩阵乘法运算时,需对向量进行增维操作,将其变成2维矩阵,矩阵相乘结束后,结果中增加的维度需要被删除。

1)m维向量与(m×n)维矩阵相乘,需先将向量变成维度为(1×m)矩阵,矩阵乘法维度变化:(1×m)×(m×n)->(1×n),生成的矩阵需删除新增维度,删除后的结果矩阵变成长度为n的1维向量。

x = torch.tensor([2,3])
y = torch.tensor([[1,1,1],[2,2,2]])
out = torch.matmul(x,y)  #out:tensor([8, 8, 8])
print(out.shape)         #torch.Size([3])

2)(m×n)维矩阵与n维向量相乘,则将向量增维成(n×1)矩阵,矩阵维度变化:(m×n)×(n×1)->(m×1),运算结果需删除新增维度,降维成m维向量。

x = torch.tensor([[3,3,3],[4,4,4]])
y = torch.tensor([2,2,2])
out = torch.matmul(x,y)  #out:tensor([18, 24])
print(out.shape)         #out.shape:torch.Size([2])

3.3 2维矩阵×2维矩阵

两个矩阵相乘时,torch.matmul()函数等价于torch.mm()函数:(m,n)×(n,t)->(m,t)

x = torch.tensor([[1,1],[3,3],[4,4]])
y = torch.tensor([[2,2,2],[5,5,5]])

out1 = torch.matmul(x,y)
print(f"out1: {out1}")

out2 = torch.mm(x,y)
print(f"out2: {out2}")

##结果##
out1: tensor([[ 7,  7,  7],
        [21, 21, 21],
        [28, 28, 28]])
out2: tensor([[ 7,  7,  7],
        [21, 21, 21],
        [28, 28, 28]])

3.4 三维矩阵相乘

对于高于二维的矩阵,第一个矩阵最后一个维度必须和第二个矩阵的倒数第二维度相同。如果是两个三维矩阵相乘,也可以使用torch.bmm()。

x = torch.randn(3,4,5)
y = torch.randn(3,5,2)
result = torch.matmul(x,y)
print(result.shape)    #shape: torch.Size([3, 4, 2])

四、torch.mv()矩阵向量乘法函数

torch.mv()用于执行2维矩阵×1维向量操作,矩阵的最后一个维度与向量长度必须相同。内部运算机理是先对向量末尾进行增维操作变成矩阵,执行矩阵乘法操作后,删除结果的最后一个维度。也可以采用上文提到的torch.matmul()。

x = torch.randint(1,4,(3,5))
y = torch.randint(1,4,(5,))
print(f"x: {x}")
print(f"y: {y}")
result = torch.mv(x,y)
print("result: {}".format(result))
print(result.shape)

##结果##
x: tensor([[2, 1, 1, 1, 2],
        [1, 1, 1, 2, 3],
        [1, 2, 1, 3, 2]])
y: tensor([1, 1, 3, 2, 2])
result: tensor([12, 15, 16])
torch.Size([3])

五、@运算符的矩阵乘法

  • 若mat1和mat2都是两个一维向量,那么对应操作就是torch.dot()
  • 若mat1是二维矩阵,mat2是一维向量,那么对应操作就是torch.mv()
  • 若mat1和mat2都是两个二维矩阵,那么对应操作就是torch.mm()

六、总结与注意

torch.mul()属于元素级操作,参与运算的矩阵要求形状相同,如果是向量与矩阵相乘,要求向量的长度与矩阵最后一个维度相同。torch.mm()只能执行二维矩阵运算,torch.matmul()适用于多维度矩阵乘法运算。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值