模型保存与加载常用有两种方法,第一种是保存整个模型,包括模型的结构和参数;第二种是保存模型的参数。推荐使用第二种,因为模型一旦很大,第一种加载耗时长,其次第二种加载方式更加灵活,可以加载其他模型的预训练参数,从而使用迁移学习的方法减小训练时长。
一、保存/加载整个模型
- 保存模型:
torch.save(net, 'model_net1.pkl')
- 加载模型
net_parm = 'model_net1.pkl'
net = torch.load(net_parm)
二、保存/加载模型参数
- 保存参数:
torch.save({
'epoch': nums_epoch,
'state_dict': net.state_dict(),
}, 'model_net.pkl')
- 加载参数:
cuda_gpu = torch.cuda.is_available()
if(cuda_gpu):
net = torch.nn.DataParallel(net, device_ids=gpus).cuda()
if os.path.isfile(net_parm):
print(