torch.baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) → Tensor
batch1和batch2执行批量的矩阵相乘,然后和input相加,得到output
如果batch1 是(b*n*m)的tensor,batch2是(b*m*p)的tensor,那么input必须是(b*n*p)的tensor.
output = input + batch1 * batch2
参数
M = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
torch.baddbmm(M, batch1, batch2).size()
输出:
torch.Size([10, 3, 5])