本篇其实与PyTorch学习笔记:使用state_dict来保存和加载模型是高度关联的,之所以单独拎出来写,主要是想突出它的重要性。
首先来描述一个本人实际遇到的问题:
首先在GPU服务器上训练了一个ResNet34的模型,然后将该模型在本人PC机(没有GPU)上进行推理,模型加载代码如下:
# load model weights
weights_path = "./resNet34.pth"
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
# 直接加载模型
model.load_state_dict(torch.load(weights_path))
结果运行时出现如下错误:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
这就引出了今天的问题ÿ