常见报错:RuntimeError: expected scalar type Long but found Float

RuntimeError: expected scalar type Long but found Float

这是一个非常常见的报错,我已经遇到过这个报错很多次了,但是之前没有仔细研究过,今天好好好看了看,终于找到了原因。
首先把导致报错的代码写出来:

import torch
import torch.nn as nn

v = torch.tensor([0])
m = nn.Linear(1, 10)
m(v)

短短的几行代码,就是初始化了一个值为0的v、一个网络m,运行后爆出了一大堆错:

Traceback (most recent call last):
  File "D:\ProgramData\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-183-2ddaa24c9bb3>", line 1, in <module>
    m(v)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: expected scalar type Long but found Float

注意到导致报错的代码: output = input.matmul(weight.t())
因为input也就是我们的v是torch.long类型的而weight是torch.float类型
所以在做矩阵乘法的时候这两种类型的不一致导致了报错
解决方法就是把v的dtype显示地设置成torch.float代码就成功运行了:

import torch
import torch.nn as nn
# dtype=torch.float必不可少
v = torch.tensor([0], dtype=torch.float)
m = nn.Linear(1, 10)
m(v)
Out[11]: 
tensor([-0.0628, -0.2544,  0.1313, -0.9293, -0.1259, -0.3151,  0.0729, -0.3097,
         0.8988,  0.1230], grad_fn=<AddBackward0>)
  • 51
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值