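torch.bmm: https://blog.csdn.net/guotong1988/article/details/78707619
Reference: https://blog.csdn.net/Real_Brilliant/article/details/85756477
torch.bmm is a batch matrix multiply: it performs an ordinary matrix multiplication on each pair of matrices in the batch. Both inputs must be 3-D tensors with the same batch size, so shapes (b, n, m) and (b, m, p) produce a result of shape (b, n, p).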
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(batch1, batch2)
>>> res.size()
torch.Size([10, 3, 5])
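As a rough check of what the example above computes, the sketch below compares torch.bmm with an explicit loop that calls torch.mm on each batch element. The name `expected` and the fixed seed are illustrative choices, not part of the original example.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)

# Batched product: one (3, 4) x (4, 5) multiplication per batch index.
res = torch.bmm(batch1, batch2)

# The same computation written as an explicit loop over the batch dimension.
expected = torch.stack(
    [torch.mm(batch1[i], batch2[i]) for i in range(batch1.size(0))]
)

print(res.shape)
# Compare with a small tolerance, since the two code paths may round differently.
print(torch.allclose(res, expected, atol=1e-5))
```

The loop is only there to make the semantics visible; torch.bmm does the same work in a single call.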
torch.mul(a, b)
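multiplies a and b element-wise. The shapes of a and b must be equal or broadcastable to a common shape; for example, if a has shape (1, 2) and b has shape (1, 2), the result also has shape (1, 2).
torch.mm(a, b)
performs matrix multiplication of a and b, both of which must be 2-D. For example, if a has shape (1, 2) and b has shape (2, 3), the result has shape (1, 3).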
import torch

a = torch.rand(1, 2)
b = torch.rand(1, 2)
c = torch.rand(2, 3)

print(torch.mul(a, b))  # returns a 1*2 tensor
print(torch.mm(a, c))   # returns a 1*3 tensor
print(torch.mul(a, c))  # raises an error: the shapes of a (1, 2) and c (2, 3) are incompatible
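torch.mul also accepts shapes that differ, as long as they can be broadcast to a common shape. The sketch below shows one case that broadcasts and one that fails; the tensor `col` and the flag `mul_failed` are hypothetical names introduced only for this illustration.

```python
import torch

a = torch.rand(1, 2)
col = torch.rand(2, 1)  # hypothetical extra tensor, not in the original example
c = torch.rand(2, 3)

# (1, 2) and (2, 1) are broadcast-compatible, so both are expanded to (2, 2).
print(torch.mul(a, col).shape)

# (1, 2) and (2, 3) are not compatible: the trailing sizes 2 and 3 differ
# and neither of them is 1.
try:
    torch.mul(a, c)
    mul_failed = False
except RuntimeError as err:
    mul_failed = True
    print("torch.mul(a, c) failed:", err)
```

Catching the RuntimeError here is only for demonstration; in real code a shape mismatch usually signals a bug that should be fixed at its source.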
>>> import torch
>>> a = [[[1, 2, 3], [2, 3, 4]]]
>>> a = torch.Tensor(a)
>>> b = [[[1, 2, 3], [2, 3, 4]]]
>>> b = torch.Tensor(b)
>>> c = torch.mul(a, b)
>>> c
tensor([[[ 1.,  4.,  9.],
         [ 4.,  9., 16.]]])
>>> c.requires_grad
False
>>> d = a * b
>>> d
tensor([[[ 1.,  4.,  9.],
         [ 4.,  9., 16.]]])
>>> d.requires_grad
False
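In the session above, requires_grad is False because neither input tracks gradients. A minimal sketch of the other case, assuming one input is created with requires_grad=True (a change from the original session): the * operator and torch.mul still give the same values, and both results then track gradients.

```python
import torch

a = torch.tensor([[[1., 2., 3.], [2., 3., 4.]]])
# Unlike the session above, b is created with gradient tracking enabled.
b = torch.tensor([[[1., 2., 3.], [2., 3., 4.]]], requires_grad=True)

c = torch.mul(a, b)  # function form
d = a * b            # operator form, same element-wise product

print(torch.equal(c, d))
print(c.requires_grad, d.requires_grad)
```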