Pytorch-模型的保存于加载

简介

Pytorch中的序列化和反序列化

  1. troch.save
    主要参数
  • obj:对象
  • f:输出路径
  1. torch.load
    主要参数
  • f:文件路径
  • map_location:指定存放位置,cpu或者gpu

对于保存有两种方法:
1.保存整个Moucle, torch.save(net,path)
2.保存模型的参数:
state_dictt=net.state_dict()
torch.save(state_dict,path)

#方式1加载模型
path_model='./model.pkl'
net_load=torch.load(path_model)

#方式2加载模型
path_state_dict="./model_state_dict.pkl"
sate_dict_load=torch.load(path_state_dict)

net.load_dict(state_dict_load)

断点续训练-checkpoint

需要保存那些信息?
在这里插入图片描述
只有模型和优化器的参数需要保存,此外还需要记录epoch
在这里插入图片描述

checkpoint_interval = 5
#中间省略了若干训练的代码
#保存check_point
if (epoch+1) % checkpoint_interval == 0:                         
                                                                 
    checkpoint = {"model_state_dict": net.state_dict(),          
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch}                                
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)  
    torch.save(checkpoint, path_checkpoint)   

#加载check_point
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch                   
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值