pytorch中的torch.bmm()用法
torch.bmm() 用于执行批量矩阵乘法,
import torch
# 创建两个张量
batch1 = torch.randn(10, 3, 4) # 形状为 (10, 3, 4)
batch2 = torch.randn(10, 4, 5) # 形状为 (10, 4, 5)
# 执行批量矩阵乘法
result = torch.bmm(batch1, batch2)
print(result.shape) # 输出 (10, 3, 5)
https://blog.csdn.net/weixin_63062756/article/details/130580454