介绍
torch.mm()
和torch.matmul()
都是PyTorch库中用来进行矩阵乘法的函数,但它们在处理输入时有一些不同。
torch.mm()
torch.mm()
:这个函数仅接受二维矩阵作为输入,并进行矩阵乘法。如果输入的张量不是二维的,它会抛出一个错误。
torch.matmul()
torch.matmul()
:这个函数可以接受高于二维的张量,并进行适当的广播和矩阵乘法。对于二维矩阵,torch.matmul()
和torch.mm()
的行为是相同的。
例子
二维矩阵
import torch
a = torch.randn(3, 4)
b = torch.randn(4, 5)
c = torch.mm(a, b) # 使用torch.mm()
d = torch.matmul(a, b) # 使用torch.matmul()
print(c.shape) # torch.Size([3, 5])
print(d.shape) # torch.Size([3, 5])
高维矩阵
e = torch.randn(2, 3, 4)
f = torch.randn(2, 4, 5)
# torch.mm仅接受二维矩阵作为输入
g = torch.mm(e, f) #报错:RuntimeError: self must be a matrix
g = torch.matmul(e, f)
print(g.shape) # torch.Size([2, 3, 5])
高维广播
除了最后两个维度外,其他的维度相等
x = torch.randn(2, 4, 3, 4)
v = torch.randn(2, 4, 4, 5)
g = torch.matmul(x, v)
print(g.shape) # torch.Size([2, 4, 3, 5])
除了最后两个维度外,其他的不相等的维度,其中一个是1,会自动广播
x1 = torch.randn(2, 1, 3, 4)
v1 = torch.randn(1, 4, 4, 5)
g1 = torch.matmul(x1, v1)
print(g1.shape) # torch.Size([2, 4, 3, 5])
除了最后两个维度外,其他的不相等的维度,且不都是1,报错
x2 = torch.randn(3, 4, 3, 4)
v2 = torch.randn(2, 1, 4, 5)
g2 = torch.matmul(x2, v2)
# 张量a的第0维的大小是3,而张量b的第0维的大小是2,这两者不匹配,都不是1,不能广播
print(g2.shape) # The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0