问题描述:
在将输入数据送入到网络模型的时候,报错:
RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
解决:
一开始在网上查找相关问题描述,其中绝大部分说的是显卡中缓存太多,需要清理缓存然后重新启动。
但是按照他们介绍的步骤,发现问题并没有得到解决,而且查看显卡状态,显存状态良好。
最后在 stackoverflow 上找到的解决方法
方法:
如果情况和我相同的话,我们输入的数据应该已经放在 GPU 上去了,我们可以去掉例如 .cuda() 的操作,将数据放回到 CPU 上重新运行代码,可以看到一个更加清楚地错误描述,其实根本问题是我们送了一个 Long 类型的数据(具体情况可能有差异),但是模型期望的却是一个 Float 类型的数据:
解决:
data = data.float().cuda()
具体解决可能有差异,但是思路应该是这样的。
END~