torch.bmm(batch1, batch2, out=None) → Tensor
参数:
- batch1 (Tensor) – 第一批相乘矩阵
- batch2 (Tensor) – 第二批相乘矩阵
- out (Tensor, optional) – 输出张量
举例
-
import torch
-
batch1 = torch.randn(8, 2, 6)
-
batch2 = torch.randn(8, 6, 10)
-
res = torch.bmm(batch1, batch2)
-
print(res.size())
结果
torch.Size([8, 2, 10])