torch.bmm 函数是 PyTorch 中用于执行批量矩阵乘法(Batch Matrix-Matrix Multiplication)的函数。它的名字 "bmm" 表示 "batch matrix multiplication"。
注意:torch.bmm()是不带广播机制的,也就说需按照矩阵运算机制。
比如:[B,3,4]*[B,4,5]是可以的,而[B,3,2]*[B,8,5]是不可以的。
语法
torch.bmm(mat1, mat2)
参数说明
mat1和mat2是两个三维张量(或者可以被广播为三维张量)。mat1的形状为(batch, n, m)。mat2的形状为(batch, m, p)。
- 返回值是一个三维张量,其形状为
(batch, n, p)
举例
import torch
# 创建两个三维张量
mat1 = torch.rand(3, 2, 4)
mat2 = torch.rand(3, 4, 3)
# 执行批量矩阵乘法
result = torch.bmm(mat1, mat2)
print(result.shape) # 输出: torch.Size([3, 2, 3])
在上面的示例中,mat1 的形状是 (3, 2, 4),mat2 的形状是 (3, 4, 3)。执行 torch.bmm(mat1, mat2) 将得到一个形状为 (3, 2, 3) 的张量,其中 3 是 batch 的大小,2 是 mat1 的行数,3 是 mat2 的列数。
3万+

被折叠的 条评论
为什么被折叠?



