pytorch基础知识整理(三)模型保存与加载

1, torch.save(); troch.load()

torch.save()使用python的pickle模块把目标保存到磁盘,可以用来保存模型、张量、字典等,文件后缀名一般用pth或pt或pkl。torch.load()使用python的pickle模块实现从磁盘加载。可以用此来直接保存或加载完整模型:

torch.save(model, 'PATH.pth')
model = torch.load('PATH.pth')

注意:pytorch1.6以后保存的模型使用zip压缩,所以保存的模型无法被1.6以前的版本加载,如果要跨版本使用,需要做以下修改

torch.save(model, 'PATH.pth', _use_new_zipfile_serialization=False)

2, .state_dict(); .load_state_dict()

模型的框架已经在程序代码中了,因此训练好的模型只需要保存模型的参数即可供推理使用。model.state_dict()以字典的形式保存模型的参数,字典的键是参数名,值是参数值的张量。得到状态字典后还需用torch.save()固化到磁盘。
除模型外,优化器optimizer也可以保存和加载状态字典。

torch.save(model.state_dict(), 'PATH.pth')
model.load_state_dict(torch.load('PATH.pth'))

注意在多卡GPU训练时,保存和加载模型需要在model后加上module,即

torch.save(model.module.state_dict(), 'PATH.pth')
model.module.load_state_dict(torch.load('PATH.pth'))

3, 保存checkpoint

如果是训练中途保存用于继续训练,就不仅要保存权重参数,还要保存当前epoch,优化器的状态,当前的损失值等,可以统一打包到一个字典中保存为checkpoint,此时文件后缀名一般用tar。

#保存:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
##加载:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值