今天看到一行pytorch的代码
import torch
from torch.autograd import Variable
tensor = torch.FloatTensor([[1,2],[3,4]])
variable = Variable(tensor, requires_grad=True)
v_out = torch.mean(variable*variable)
很理所当然的理解为两个矩阵相乘,但是打印输出看的时候觉得不对
tensor([[ 1., 4.],
[ 9., 16.]], grad_fn=<MulBackward0>)
这里明显做了一个点乘
那如何才能让这两个variable变量做矩阵的乘法呢
print(torch.mm(variable,variable))
tensor([[ 7., 10.],
[15., 22.]], grad_fn=<MmBackward>)
嗯,这样就可以了