从官方文档可以看出
-
mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是(b * m), (m * k) 得到(b * k)
-
bmm是两个三维张量相乘, 两个tensor维度是,(b * m * n), (b * n * k) 得到(b * m * k)
-
matmul可以进行张量乘法, 输入可以是高维.
总结:
对位相乘用torch.mul;二维矩阵乘法用torch.mm;batch二维矩阵用torch.bmm;batch,广播用torch.matmul
参考:
https://blog.csdn.net/Real_Brilliant/article/details/85756477