pytorch快速上手(5)-----pytorch模型的保存加载与断点恢复训练

模型的保存与加载

PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)

  • torch.save主要参数: obj:对象 、f:输出路径
  • torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu
    在这里插入图片描述

一、常见的模型保存的两种方法:

1、保存整个Module

torch.save(net, path)

在这里插入图片描述

2、保存模型参数

state_dict = net.state_dict()
torch.save(state_dict , path)

二、训练过程中自定义保存内容 与 断点恢复训练

#加载恢复
if RESUME:#是否恢复
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点模型文件路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    start_epoch = checkpoint['epoch']  # 设置开始的epoch
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler



#保存
for epoch in range(start_epoch + 1, 80):

    optimizer.zero_grad()

    optimizer.step()
    lr_schedule.step()


    if epoch %20 == 19:#每隔20个epoch保存一次模型
        print('epoch:',epoch)
        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
        #自定义要保存的参数信息
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

在这里插入图片描述
在这里插入图片描述

模型层该改变时,再次载入之前模型的权重,只需要 model.load_state_dict(torch.load(PATH), strict = False)

在这里插入图片描述
在这里插入图片描述

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值