成功解决
RuntimeError: Attempting to deserialize object on CUDA device 1 but torch.cuda.device_count() is 1.
报错内容
程序在这一步报错
checkpoint = torch.load(‘model5_4.pt’)
以上问题描述是说未获取到当前环境下的 cuda,因为我的模型是在服务器上跑的,下载到本地后环境不同。
解决方法
若你当前在只有 CPU 环境下运行的话,需要加上map_location=torch.device(‘cpu’)。
若你当前在有 CUDA环境下运行的话,需要加上map_location=torch.device(‘cuda’)。
checkpoint = torch.load(‘model5_4.pt’)
即换成:
checkpoint = torch.load('model5_4.pt',map_location='cuda')
运行成功!不报错了!