pytorch断点续传


前言

当在模型训练过程中遇到下面的情况时我们就需要断点续传的技巧了

  • 本地训练到一半但由于有其他事情或事故必须主动或被动中断正在训练的模型等待后续再继续训练
  • 云端训练模型时由于平台的不稳定性导致训练中断,例如colab等。

一、断点续传的作用?

断点续传会在模型训练到一定时期时保存一次当前训练的数据,保存下的数据是以字典的形式序列化存储的,后续再通过pytorch反序列化读取即可。

二、具体步骤

1.保存断点

首先需要设置一个保存周期变量checkpoint_interval,具体的值可以自定义,值过小的话保存次数过多训练时间就会增强,过大就容易导致马上就要达到一个保存周期时训练中断,整个周期几乎是重新训练。具体代码如下:

    checkpoint_interval = 3
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint = {"model_state_dict": model.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

这里存储了模型的参数,优化器的参数以及当前epoch。存储优化器的参数是因为优化器中存储了当前权重更新状态的相关参数。

2.加载断点

这里设置了一个start_epoch,它的值来自断点中存储的epoch值,代表当前要继续的epoch值。resume是个布尔值,代表是否继续训练,若要继续训练则手动设置为True。另外在读取断点时还设置了schedulerepoch,这是因为现在的scheduler的更新策略往往跟当前的epoch是有关系的,例如随着epoch的增加学习率的梯度越来越小。

start_epoch = 0
resume = False
path_checkpoint = "checkpointfirst_7_epoch.pkl"#断点路径
if resume:
	checkpoint = torch.load(path_checkpoint)#加载断点
    model.load_state_dict(checkpoint['model_state_dict'])#加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])#加载优化器参数
    start_epoch = checkpoint['epoch']#设置开始的epoch
    scheduler.last_epoch = start_epoch#设置学习率的last_epoch

三、其他需注意的地方

  • 在我们使用预训练模型微调时,会先将预训练模型的前几层冻结着重训练后面自己添加的层。例如使用resnet101模型做微调时,先将前5层进行冻结,只训练最后一层全连接层。冻结时会使用下面的代码(这里以resnet101举例):
for child in model.children():
    ct += 1
    # print(ct,child)
    if ct < 5:
        for param in child.parameters():
            param.requires_grad = False

但当模型的参数存入断点文件时,是不会存储参数requires_grad 的。因此若要设置某些层不更新参数则需要在读取断点后执行相应设置。这样无论是重新训练还是继续训练某些层的requires_grad 都是满足需求的。

  • 即使在保存断点时这些参数是在GPU上,读取时仍然默认在CPU中,因此还是需要添加model = model.to(device)

无论是上述哪个点,都只要把相应操作的代码放在模型加载之后即可

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

红糖毛血旺

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值