torch.matmul
是 PyTorch 中用于执行矩阵乘法的函数。它的设计目的是为了处理广泛的输入形状和操作,包括矩阵乘法、向量内积、批量矩阵乘法等。
语法
torch.matmul(input, other, out=None)
参数说明
input
: 输入张量。other
: 第二个输入张量。out
(可选): 输出张量。
举例
- 如果输入都是 1-D 张量,执行向量的内积。
- 如果输入都是 2-D 张量,执行矩阵乘法。
- 如果输入中至少有一个张量的维度大于 2,执行批量矩阵乘法。
- 支持广播机制,根据 NumPy 广播规则进行自动广播。
矩阵乘法
import torch
# 创建两个矩阵
mat1 = torch.rand(2, 3)
mat2 = torch.rand(3, 4)
# 执行矩阵乘法
result_matmul = torch.matmul(mat1, mat2)
print(result_matmul.shape) # 输出: torch.Size([2, 4])
批量矩阵乘法
import torch
# 创建两个三维张量
mat1 = torch.rand(3, 2, 3)
mat2 = torch.rand(3, 3, 4)
# 执行批量矩阵乘法
result_matmul_batch = torch.matmul(mat1, mat2)
print(result_matmul_batch.shape) # 输出: torch.Size([3, 2, 4])
向量的内积
import torch
# 创建两个向量
vec1 = torch.rand(3)
vec2 = torch.rand(3)
# 执行向量的内积
result_inner_product = torch.matmul(vec1, vec2)
print(result_inner_product) # 输出: 一个标量值
总的来说,torch.matmul
是一个非常通用的函数,能够处理多种输入形状,包括矩阵乘法、向量内积、批量矩阵乘法等。根据输入的具体形状,它会自动选择适当的操作。