Error
File "/home/xovee/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/serialization.py", line 142, in validate_cuda_device
raise RuntimeError('Attempting to deserialize object on CUDA device '
RuntimeError: Attempting to deserialize object on CUDA device 1 but torch.cuda.device_count() is 1. Please use torch.load with map_location to map your storages to an existing device.
Environment
- PyTorch 1.9
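Cause
GPU device mismatch: the weights were saved on 'cuda:1', but the current machine has only one GPU ('cuda:0'), so torch.load cannot restore them to their original device.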
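Each tensor in a checkpoint carries a tag for the device it was saved from, and by default torch.load tries to restore it there. To check which devices a checkpoint expects, you can pass a callable as map_location; PyTorch calls it once per storage with the CPU-resident storage and its saved location string. The helper name saved_devices below is hypothetical, and this is a minimal sketch rather than an official utility:

```python
import os
import tempfile

import torch


def saved_devices(path):
    """Hypothetical helper: return the set of device tags stored in a checkpoint."""
    seen = set()

    def record(storage, location):
        # location is the tag recorded at save time, e.g. 'cpu' or 'cuda:1'
        seen.add(location)
        return storage  # returning the storage unchanged keeps it on the CPU

    torch.load(path, map_location=record)
    return seen


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.pt")
        torch.save({"w": torch.zeros(2)}, path)
        print(saved_devices(path))  # {'cpu'} for a checkpoint saved on the CPU
```

Running it on the checkpoint from this error would be expected to report 'cuda:1'.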
Solution
Original code:
model.load_state_dict(torch.load(model_path))
Fixed code:
# map_location remaps storages saved on 'cuda:1' onto 'cuda:0', the only GPU present
loaded_state = torch.load(model_path, map_location='cuda:0')
model.load_state_dict(loaded_state)
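Hard-coding 'cuda:0' still fails on a machine with no GPU. A slightly more portable variant, sketched below under the assumption that you want "first GPU if available, else CPU", picks the target device at runtime. The helper name load_weights is hypothetical. If you only need to rename one device, torch.load also accepts a dict such as map_location={'cuda:1': 'cuda:0'}.

```python
import torch
import torch.nn as nn


def load_weights(model: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: load a state_dict onto whatever device this machine has."""
    # Choose the device from the current machine, not the one used at save time
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return model.to(device)
```

Usage: model = load_weights(MyModel(), model_path), where MyModel stands in for your own model class.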