在保存模型的时候,我们没有保存整个模型,只保存了模型的参数,因为保存整个模型可能出现一些意外的错误。
保存模型参数为如下代码:
torch.save(model.state_dict(),path)
1 cpu训练->cpu测试Or gpu训练->gpu测试
model = torch.load('model_test.pth')
2 cpu训练->gpu测试
model = torch.load('model_test.pth', map_location=lambda storage, loc: storage.cuda(1))
3 gpu:1训练->gpu:0测试
model = torch.load('model_test.pth', map_location={'cuda:1':'cuda:0'})
4 gpu训练->cpu测试
model = torch.load('model_test.pth', map_location=lambda storage, loc: storage)