在 PyTorch 中,torch.matmul
和 @
操作符都用于执行矩阵乘法,但它们在某些特定情况下有略微不同的行为。下面我们详细介绍两者之间的区别与相同之处。
1. torch.matmul
torch.matmul
是一个通用的矩阵乘法函数,它能够处理不同维度的张量,依据输入张量的维度自动决定执行何种类型的乘法。
torch.matmul
的行为:
- 如果输入是 1D 向量:执行 向量内积。
- 如果输入是 2D 矩阵:执行 标准的矩阵乘法。
- 如果输入是高于 2D 的张量:执行 批量矩阵乘法,即对批量中的每对矩阵分别执行矩阵乘法。
例子:
import torch
# 1D 张量(向量)的点积
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.matmul(a, b)
print(result) # 输出: 32 (即 1*4 + 2*5 + 3*6)
# 2D 矩阵的标准矩阵乘法
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.matmul(A, B)
print(result) # 输出: tensor([[19, 22], [43, 50]])
# 3D 张量的批量矩阵乘法
C = torch.randn(10, 3, 4) # 形状 (10, 3, 4) 的批量矩阵
D = torch.randn(10, 4, 5) # 形状 (10, 4, 5) 的批量矩阵
result = torch.matmul(C, D) # 批量中的每个 (3, 4) 矩阵和 (4, 5) 矩阵相乘,输出形状为 (10, 3, 5)
print(result.shape) # 输出: torch.Size([10, 3, 5])
2. @
操作符
@
操作符是在 Python 3.5 中引入的用于矩阵乘法的快捷符号。它也可以处理 1D 向量、2D 矩阵 以及高维张量,但它的行为与 torch.matmul
是一致的。
@
操作符的行为:
- 如果输入是 1D 向量:执行 向量内积。
- 如果输入是 2D 矩阵:执行 矩阵乘法。
- 如果输入是高于 2D 的张量:执行 批量矩阵乘法(与
torch.matmul
相同)。
例子:
import torch
# 1D 向量的点积
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = a @ b
print(result) # 输出: 32
# 2D 矩阵乘法
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = A @ B
print(result) # 输出: tensor([[19, 22], [43, 50]])
# 3D 张量的批量矩阵乘法
C = torch.randn(10, 3, 4)
D = torch.randn(10, 4, 5)
result = C @ D
print(result.shape) # 输出: torch.Size([10, 3, 5])
3. 相同点:
- 矩阵和向量的乘法:
torch.matmul
和@
对 1D 和 2D 张量的乘法行为是完全相同的。 - 批量矩阵乘法:对于更高维度的张量,它们都支持批量矩阵乘法,并且表现一致。
4. 不同点:
唯一的区别在于可读性和使用场景:
@
操作符:作为 语法糖,更简洁直观,在书写矩阵乘法时更具可读性,尤其是在 Python 代码中类似于数学符号。torch.matmul
:作为 PyTorch 提供的函数,功能上更明确,适合程序化或需要在函数式编程中灵活调用的场景。
总结:
torch.matmul
是 PyTorch 中的通用矩阵乘法函数,适用于从向量到批量矩阵的各种乘法场景。@
操作符是 Python 的矩阵乘法符号,行为与torch.matmul
相同,但更简洁。- 二者在功能上基本一致,选择哪个取决于代码的风格和偏好。如果你想要书写更简洁的代码,
@
操作符是一个很好的选择;如果你需要在函数或复杂场景中调用矩阵乘法,torch.matmul
更为合适。