引 言 torch中的tensor张量之间相乘操作分为矩阵乘法和元素级乘法,两种乘法运算方式对于初学者而言很容易混淆。结合本人在实践操作中的经验,将pytorch中常用torch.mul()、torch.mm()和torch.matmul()等函数的用法进行详细介绍和举例说明。
目录
一、torch.mul()矩阵元素级乘法函数
torch.mul()函数主要对矩阵中的元素实施Hadamard积运算,该运算属于元素级相乘操作。可以直接使用“ * ”替换torch.mul()函数。在矩阵运算中,要求两个矩阵的维度相同,矩阵,
,矩阵A和B的Hadamard积为:
矩阵元素级乘法也可以用于向量×矩阵的情况,此时要求向量的长度与矩阵最后一个维度相同,采用广播机制变成与矩阵相同的形状,随后进行逐元素相乘操作。
x = torch.tensor([[1,1],[3,3],[4,4]])
y = torch.tensor([2,2])
out1 = torch.mul(x,y) #等价于out1 = x*y
#结果
tensor([[2, 2],
[6, 6],
[8, 8]])
二、torch.mm()二维矩阵乘法函数
torch.mm()只适合于二维矩阵乘法运算,如果矩阵维度超过两个维度则会报错。二维矩阵乘法运算要求第一个矩阵的列数与第二个矩阵的行数相同。
import torch
A = torch.randint(1,5,size=(2,3))
B = torch.randint(1,5,(3,2))
print('A: \n',A)
print('B: \n',B)
result = torch.mm(A,B)
print('result: \n {}'.format(result))
##结果##
A:
tensor([[2, 3, 2],
[1, 4, 4]])
B:
tensor([[2, 2],
[4, 4],
[2, 3]])
result:
tensor([[20, 22],
[26, 30]])
三、torch.matmul()矩阵乘法函数
torch.matmul()属于广义矩阵乘法函数操作,适用形式有:1维向量×1维向量,1维向量×2维矩阵,2维矩阵×1维向量,任意维度矩阵相乘等。每种情况的具体使用会结合示例代码逐一介绍。
3.1 1维向量×1维向量的内积运算
torch.matmul()函数作用于两个1维向量运算时,两个向量长度相同,主要对两个1维向量进行内积运算(结果为标量)。功能与torch.dot()函数相同(torch.dot()函数只能用于1维向量运算)。
x = torch.tensor([2,3,4])
y = torch.tensor([2,2,2])
out1 = torch.matmul(x,y) #out1 : tensor(18)
out2 = torch.dot(x,y) #out2 : tensor(18)
3.2 1维向量×2维矩阵或2维矩阵×1维向量
向量与矩阵做矩阵乘法运算时,需对向量进行增维操作,将其变成2维矩阵,矩阵相乘结束后,结果中增加的维度需要被删除。
1)m维向量与(m×n)维矩阵相乘,需先将向量变成维度为(1×m)矩阵,矩阵乘法维度变化:(1×m)×(m×n)->(1×n),生成的矩阵需删除新增维度,删除后的结果矩阵变成长度为n的1维向量。
x = torch.tensor([2,3])
y = torch.tensor([[1,1,1],[2,2,2]])
out = torch.matmul(x,y) #out:tensor([8, 8, 8])
print(out.shape) #torch.Size([3])
2)(m×n)维矩阵与n维向量相乘,则将向量增维成(n×1)矩阵,矩阵维度变化:(m×n)×(n×1)->(m×1),运算结果需删除新增维度,降维成m维向量。
x = torch.tensor([[3,3,3],[4,4,4]])
y = torch.tensor([2,2,2])
out = torch.matmul(x,y) #out:tensor([18, 24])
print(out.shape) #out.shape:torch.Size([2])
3.3 2维矩阵×2维矩阵
两个矩阵相乘时,torch.matmul()函数等价于torch.mm()函数:(m,n)×(n,t)->(m,t)
x = torch.tensor([[1,1],[3,3],[4,4]])
y = torch.tensor([[2,2,2],[5,5,5]])
out1 = torch.matmul(x,y)
print(f"out1: {out1}")
out2 = torch.mm(x,y)
print(f"out2: {out2}")
##结果##
out1: tensor([[ 7, 7, 7],
[21, 21, 21],
[28, 28, 28]])
out2: tensor([[ 7, 7, 7],
[21, 21, 21],
[28, 28, 28]])
3.4 三维矩阵相乘
对于高于二维的矩阵,第一个矩阵最后一个维度必须和第二个矩阵的倒数第二维度相同。如果是两个三维矩阵相乘,也可以使用torch.bmm()。
x = torch.randn(3,4,5)
y = torch.randn(3,5,2)
result = torch.matmul(x,y)
print(result.shape) #shape: torch.Size([3, 4, 2])
四、torch.mv()矩阵向量乘法函数
torch.mv()用于执行2维矩阵×1维向量操作,矩阵的最后一个维度与向量长度必须相同。内部运算机理是先对向量末尾进行增维操作变成矩阵,执行矩阵乘法操作后,删除结果的最后一个维度。也可以采用上文提到的torch.matmul()。
x = torch.randint(1,4,(3,5))
y = torch.randint(1,4,(5,))
print(f"x: {x}")
print(f"y: {y}")
result = torch.mv(x,y)
print("result: {}".format(result))
print(result.shape)
##结果##
x: tensor([[2, 1, 1, 1, 2],
[1, 1, 1, 2, 3],
[1, 2, 1, 3, 2]])
y: tensor([1, 1, 3, 2, 2])
result: tensor([12, 15, 16])
torch.Size([3])
五、@运算符的矩阵乘法
- 若mat1和mat2都是两个一维向量,那么对应操作就是torch.dot()
- 若mat1是二维矩阵,mat2是一维向量,那么对应操作就是torch.mv()
- 若mat1和mat2都是两个二维矩阵,那么对应操作就是torch.mm()
六、总结与注意
torch.mul()属于元素级操作,参与运算的矩阵要求形状相同,如果是向量与矩阵相乘,要求向量的长度与矩阵最后一个维度相同。torch.mm()只能执行二维矩阵运算,torch.matmul()适用于多维度矩阵乘法运算。