RuntimeError: expected scalar type Long but found Float
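This is a very common error. I have run into it many times without ever looking into it properly, so today I finally sat down, worked through it, and found the cause.
First, here is the code that triggers the error: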
import torch
import torch.nn as nn
v = torch.tensor([0])
m = nn.Linear(1, 10)
m(v)
These few lines only create a tensor v holding the value 0 and a network m, yet running them produces a long traceback:
Traceback (most recent call last):
File "D:\ProgramData\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-183-2ddaa24c9bb3>", line 1, in <module>
m(v)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1692, in linear
output = input.matmul(weight.t())
RuntimeError: expected scalar type Long but found Float
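Note the line that raises the error: output = input.matmul(weight.t())
Here input, which is our v, has dtype torch.long, because torch.tensor([0]) infers an integer dtype from the Python int 0, while weight has dtype torch.float.
The matrix multiplication fails because the two operands have different dtypes.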
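To see the mismatch directly, you can print both dtypes before calling the layer. This is a minimal sketch that assumes PyTorch's default dtype has not been changed (for example via torch.set_default_dtype), in which case the layer's parameters are float32:

```python
import torch
import torch.nn as nn

v = torch.tensor([0])   # built from a Python int, so the dtype is inferred as int64 (long)
m = nn.Linear(1, 10)    # parameters are created in the default floating-point dtype

print(v.dtype)          # torch.int64
print(m.weight.dtype)   # torch.float32 under the default settings
```

torch.long is an alias for torch.int64 and torch.float is an alias for torch.float32, which is why the error message talks about Long and Float.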
The fix is to explicitly set the dtype of v to torch.float, after which the code runs successfully:
import torch
import torch.nn as nn
# dtype=torch.float is required here
v = torch.tensor([0], dtype=torch.float)
m = nn.Linear(1, 10)
m(v)
Out[11]:
tensor([-0.0628, -0.2544, 0.1313, -0.9293, -0.1259, -0.3151, 0.0729, -0.3097,
0.8988, 0.1230], grad_fn=<AddBackward0>)
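If the tensor already exists (for example, it comes out of a dataset as integers), it can also be cast just before the forward pass instead of being recreated. The sketch below is one possible way to do that, not the only one; casting to m.weight.dtype is my own suggestion for keeping the input in step with whatever dtype the layer uses:

```python
import torch
import torch.nn as nn

v = torch.tensor([0])               # integer tensor, dtype torch.int64
m = nn.Linear(1, 10)

out = m(v.float())                  # cast to float32 before the forward pass
out2 = m(v.to(m.weight.dtype))      # or match the dtype of the layer's weights

print(out.shape)                    # torch.Size([10])
```

Writing the literal as a float, torch.tensor([0.0]), also avoids the error, since the dtype is then inferred as a floating-point type.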