源于多维tensor的乘法不知道应该应用哪个函数,然后发现了别人总结的好,特意借鉴过来,以备不时之需。
1. 二维矩阵乘法 torch.mm()
torch.mm(mat1, mat2, out=None)
其中
,
, 输出的
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
2. 三维带batch的矩阵乘法 torch.bmm()
由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供
torch.bmm(bmat1, bmat2,