最近使用Pytorch在学习一个深度学习项目,在模型保存和加载过程中遇到了问题,最终通过在网卡查找资料得已解决,故以此记之,以备忘却。
首先,是在使用多GPU进行模型训练的过程中,在保存模型参数时,应该使用类似如下代码进行保存:
torch.save({
'epoch': epoch,
'state_dict': model.module.state_dict(),
'optimizer': optimizer.state_dict()
}, 'results/checkpoint_net.pth')
对应的在加载模型参数时,使用如下代码进行加载是没有问题的:
checkpoint = torch.load('./results/checkpoint_net.pth')
model.load_state_dict(checkpoint['model'])
一般情况下,在保存模型时我们不会发现会有什么不对,而是在需要加载模型参数时,才发现加载报错了。比如:
这时我们需要回头检查我们在保存模型参数时,是否有哪里不对。比如我这次就是这样的,写代码的时候并没有考虑到多GPU的情况,所以保存代码如下:
torch.save({
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()