1 问题描述
今天在使用PyTorch搭建模型时,出现了一个问题:
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'weight' in call to _thnn_conv2d_forward
是在val阶段的forward函数的第一句出现的,
# 首先进行维度提升以适应torch的要求 img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0) img = img.to(self.device) x = img # x = x.float() # Max pooling over a (2, 2) window x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # 错误定位到这里
2 解决方案
可以看到这里出错的原因是weight的不匹配,
那么weight代表了什么含义呢?(其实后来解决了我才知道),这里的weight就是就是模型的parameters,
因为我们在Net的初始化代码中没有对模型参数的变量类型进行设置,所以默认就是Float类型;
而我们传入的变量x(也就是图像数据),其实是inte64类型,也torch的计算时,他转换成了Long类型,
而这里的模型参数依然是Float类型,所以出现了类型不一致的问题;
解决方法就是,在将变量x输入到网络中去之前,对其进行类型转换,
x = x.float()
这样既可以保证变量类型与模型参数类型相匹配,同时我们还可以使用Float类型进行运算,使得结果的精度更高;
3 致谢
感谢Imanol的帮助,
原文链接如下: