我的CUDA version=11.4
百度、谷歌上能找的办法都找了,有各种解决办法
- 减小batch_size的
- 把所有net和输入都放入到相同设备的
- net的train和eval状态检查的
- 防止内存中的数据混乱,将前面输出clone的
- 设置CUDA_LAUNCH_BLOCKING=1的
- 设置torch.cuda.set_device(1)的
- cudnn版本不对,需要重装的
都试了个遍,都没用
最后试着将pytorch的版本回退了一下,由v1.12.0退到v1.11.0之后,发现问题解决了
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch