pytorch训练中断后,如何在之前的断点处继续训练

我们在训练模型的时候经常出现各种问题导致训练中断,比方说断电,或者关机之类的导致电脑系统关闭,从而将模型训练中断,那么如何在模型中断后,能够保留之前的训练结果不被丢失,同时又可以继续之前的断点处继续训练?

首先在代码离需要保存模型,比方说我们模型设置训练5000轮,那么我们可以选择每100轮保存一次模型,这样的话,在训练的过程中就能保存下100,200,300.。。。等轮数时候的模型,那么当模型训练到400轮的时候突然训练中断,那么我们就可以通过加载400轮的参数来进行继续训练,其实这个过程就类似在预训练模型的基础上进行训练。下面简单粗暴上代码:

1、保存模型

torch.save(checkpoint, checkpoint_path)

其中checkpoint其实保存的就是模型的一些参数,比方说下面这种字典形式的保存所需的模型参数:

checkpoint = {
    'model': model_state_dict,
    'generator': generator_state_dict,
    'opt': model_opt,
    'optim': optim,
}

checkpoint_path则是表示保存的模型

checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)

save_checkpoint_steps是保存的间隔轮数,step是保存的轮数,比方说save_checkpoint_steps=100,那么step的取值就是100,200,300,400等,下面的代码解释step的取值由来。

if step % self.save_checkpoint_steps != 0:
    return
chkpt, chkpt_name = self._save(step)

其中_save函数就是实现了前面checkpoint的内容的保存。

模型的保存设置就此结束。

2、模型的加载

假如此时模型训练中断了,我们得在代码里设置一个参数,这个参数用来查找确定当前路径下是否有已存在得模型。

# 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        start_epoch = checkpoint['model_opt']
        optim=checkpoint['optim']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

或者设置一个变量train_from,若赋值已有模型得路径,则继续训练;若为None,那么从头训练。这块代码既可以用于训练中断,又可以用于使用预训练模型。

if opt.train_from:#是否存在预训练模型
    logger.info('Loading checkpoint from %s' % opt.train_from)
    checkpoint = torch.load(opt.train_from)#加载预训练模型的检查点
    model_opt = checkpoint['opt']
else:
    checkpoint = None
    model_opt = opt

加油,come on!

  • 25
    点赞
  • 205
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论
对于 Mask R-CNN 的接续训练,你需要遵循以下步骤: 1. 数据准备:收集足够数量的标注数据,包括图像和相应的标注信息,例如物体边界框和掩码。确保标注数据与初始训练数据集保持一致。 2. 模型配置:根据你的需求,配置 Mask R-CNN 模型的超参数,如学习率、批次大小、迭代次数等。你可以使用开源的 Mask R-CNN 实现,如 Detectron2 或 mmdetection,根据自己的需求进行修改。 3. 模型初始化:使用已经训练好的 Mask R-CNN 模型作为初始模型。你可以使用预训练的权重,也可以使用之前训练过的模型。 4. 训练过程:在接续训练中,你需要加载初始模型的权重,并使用新的数据集进行迭代训练。通常情况下,你可以选择冻结初始模型的部分层,只更新与新数据集相关的层,以加快训练速度。 5. 学习率调整:可以根据训练过程中的性能表现,适时调整学习率。常见的策略包括学习率衰减和学习率预热。 6. 评估与调优:在每个训练周期结束后,使用验证集对模型进行评估。根据评估结果,调整模型的超参数或训练策略,以提升模型性能。 7. 迭代训练:根据需要,可以进行多轮的迭代训练,直到模型达到满意的性能水平。 请注意,接续训练需要更多的计算资源和训练时间,因此在进行接续训练之前,请确保你具备足够的计算资源和时间。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序小K

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值