使用PyTorch前向运算时出现“RuntimeError: Expected object of scalar type Long but got scalar type Float for ……”

1 问题描述

今天在使用PyTorch搭建模型时,出现了一个问题:

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'weight' in call to _thnn_conv2d_forward

是在val阶段的forward函数的第一句出现的,

# 首先进行维度提升以适应torch的要求
img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
img = img.to(self.device)
x = img
# x = x.float()

# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # 错误定位到这里

2 解决方案

可以看到这里出错的原因是weight的不匹配,

那么weight代表了什么含义呢?(其实后来解决了我才知道),这里的weight就是就是模型的parameters,

因为我们在Net的初始化代码中没有对模型参数的变量类型进行设置,所以默认就是Float类型;

而我们传入的变量x(也就是图像数据),其实是inte64类型,也torch的计算时,他转换成了Long类型,

而这里的模型参数依然是Float类型,所以出现了类型不一致的问题;

解决方法就是,在将变量x输入到网络中去之前,对其进行类型转换,

x = x.float()

这样既可以保证变量类型与模型参数类型相匹配,同时我们还可以使用Float类型进行运算,使得结果的精度更高;

 

3 致谢

感谢Imanol的帮助,

原文链接如下:

https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-double-but-got-scalar-type-float-for-argument-2-weight/38961/9?u=songyuc

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值