Pytorch保存我们训练好的模型,然后加载用于测试
第一种方法
(1)保存
torch.save(model.state_dict(), PATH)
# example
torch.save(resnet50.state_dict(),'ckp/model.pth')
(2)恢复
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
#example
resnet=resnet50(pretrained=True)
resnet.load_state_dict(torch.load('ckp/model.pth'))
第二种方法
(1)保存
torch.save (model, PATH)
(2)恢复
model = torch.load(PATH)