Implementing checkpoints in PyTorch

Checkpoint (PyTorch)

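A deep learning model's parameters need to be saved during training. "Checkpoint" is the term for a snapshot of the model parameters saved after each training epoch. It works like a save point in a video game: you can load the saved file at any time and continue from where you left off.
Training a deep learning model often takes a long time. To avoid losing progress, it is advisable to checkpoint the model's parameters at the end of every epoch, and to keep a separate copy whenever they are the best parameters seen so far.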
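Concretely, a snapshot in PyTorch is usually a plain dict holding the model's `state_dict()` plus any bookkeeping needed to resume, written with `torch.save` and read back with `torch.load`. The minimal sketch below shows that round trip; the toy `nn.Linear(4, 2)` model, the `epoch` value and the file name are illustrative assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn

# A toy model standing in for a real network (illustrative assumption)
model = nn.Linear(4, 2)

# The "snapshot": parameters plus bookkeeping needed to resume
state = {"epoch": 3, "state_dict": model.state_dict()}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "snapshot.pth.tar")
    torch.save(state, path)      # write the snapshot to disk
    loaded = torch.load(path)    # read it back as a dict

restored = nn.Linear(4, 2)                      # freshly initialised weights
restored.load_state_dict(loaded["state_dict"])  # overwrite them with the saved ones

# Compare every saved tensor with its restored counterpart
same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                             restored.state_dict().values()))
print(loaded["epoch"], same)  # 3 True
```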

Creating checkpoints with PyTorch

Saving a checkpoint

import os
import shutil

import torch


def save_checkpoint(state, is_best, checkpoint):
    """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves
    checkpoint + 'best.pth.tar'
    Args:
        state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
        is_best: (bool) True if it is the best model seen till now
        checkpoint: (string) folder where parameters are to be saved
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
        # makedirs also creates any missing parent folders (os.mkdir would fail)
        os.makedirs(checkpoint)
    else:
        print("Checkpoint Directory exists! ")
    # Always overwrite the most recent checkpoint
    torch.save(state, filepath)
    if is_best:
        # Keep a separate copy of the best weights seen so far
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
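One weakness of writing `last.pth.tar` directly is that a job killed in the middle of `torch.save` can leave a truncated file behind in place of the previous good checkpoint. A common remedy, sketched below under the assumption that the checkpoint folder lives on a single local filesystem, is to save to a temporary file in the same folder and then swap it in with `os.replace`, which performs an atomic rename on POSIX systems. `save_checkpoint_atomic` is a hypothetical name, not part of PyTorch.

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn


def save_checkpoint_atomic(state, is_best, checkpoint):
    """Hypothetical variant of save_checkpoint that never leaves a half-written last.pth.tar."""
    os.makedirs(checkpoint, exist_ok=True)  # no-op if the folder already exists
    filepath = os.path.join(checkpoint, "last.pth.tar")
    # Write to a temporary file in the same folder first...
    fd, tmp_path = tempfile.mkstemp(dir=checkpoint, suffix=".tmp")
    os.close(fd)
    try:
        torch.save(state, tmp_path)
        # ...then swap it in with a single rename
        os.replace(tmp_path, filepath)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # do not leave stray temp files behind
        raise
    if is_best:
        # best.pth.tar is still written with a plain (non-atomic) copy
        shutil.copyfile(filepath, os.path.join(checkpoint, "best.pth.tar"))
    return filepath


# Usage with a toy model (illustrative)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "experiment")
    save_checkpoint_atomic({"epoch": 1,
                            "state_dict": model.state_dict(),
                            "optim_dict": optimizer.state_dict()},
                           is_best=True, checkpoint=ckpt_dir)
    files = sorted(os.listdir(ckpt_dir))
print(files)  # ['best.pth.tar', 'last.pth.tar']
```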

Loading a checkpoint

import os

import torch


def load_checkpoint(checkpoint, model, optimizer=None):
    """Loads model parameters (state_dict) from the file at `checkpoint`. If optimizer is provided, loads
    the optimizer's state_dict, assuming it is present in the checkpoint.
    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim.Optimizer) optional: resume optimizer from checkpoint
    """
    if not os.path.exists(checkpoint):
        # raise needs an exception object; raising a plain string is itself a TypeError
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    checkpoint = torch.load(checkpoint)
    model.load_state_dict(checkpoint['state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optim_dict'])

    # Return the whole dict so the caller can read extra keys such as 'epoch'
    return checkpoint
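To check that a checkpoint really restores the training state, the sketch below saves a toy model and its SGD optimizer (with momentum, so the optimizer carries internal buffers) under the same `state_dict`/`optim_dict` keys used above, then loads them into freshly built objects. It also passes `map_location="cpu"` to `torch.load`, which helps when a checkpoint written on a GPU must be opened on a machine without one. The toy model and hyperparameters are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Take one training step so the optimizer has state (momentum buffers)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "last.pth.tar")
    torch.save({"epoch": 1,
                "state_dict": model.state_dict(),
                "optim_dict": optimizer.state_dict()}, path)

    # Rebuild model and optimizer exactly as before, then restore their states.
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
    new_model = nn.Linear(4, 2)
    new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
    ckpt = torch.load(path, map_location="cpu")
    new_model.load_state_dict(ckpt["state_dict"])
    new_optimizer.load_state_dict(ckpt["optim_dict"])

weights_match = torch.equal(model.weight, new_model.weight)
momentum_restored = (len(new_optimizer.state) == 2 and
                     all("momentum_buffer" in s for s in new_optimizer.state.values()))
print(weights_match, momentum_restored)  # True True
```

Without the `optim_dict` entry, the momentum buffers would start from zero after a restart, so the resumed run would not continue exactly where it stopped.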

When running training and evaluation, the script needs to do the following:

  1. Decide whether to restore from a checkpoint
if args.restore_file is not None:
    # e.g. restore_file='best' loads <model_dir>/best.pth.tar
    restore_path = os.path.join(args.model_dir, args.restore_file + '.pth.tar')
    logging.info("Restoring parameters from {}".format(restore_path))
    utils.load_checkpoint(restore_path, model, optimizer)
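The training loop that follows always iterates over `range(params.num_epochs)`, so a restored run would start counting from epoch 0 again. Because the checkpoint stores `'epoch': epoch + 1`, that stored value can serve directly as the starting index. A minimal sketch, assuming a `last.pth.tar` with the keys used above; `start_epoch` and `num_epochs = 5` are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

num_epochs = 5  # illustrative value
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "last.pth.tar")
    # Pretend the previous run finished epoch index 2, so it stored epoch + 1 == 3
    torch.save({"epoch": 3,
                "state_dict": model.state_dict(),
                "optim_dict": optimizer.state_dict()}, path)

    # Restore, then continue counting from the stored value instead of from 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"])
    optimizer.load_state_dict(ckpt["optim_dict"])
    start_epoch = ckpt.get("epoch", 0)  # hypothetical name; 0 if the key is missing

epochs_run = []
for epoch in range(start_epoch, num_epochs):
    epochs_run.append(epoch + 1)  # the 1-based number the training loop would log
print(epochs_run)  # [4, 5]
```

In the script above, the same value is available from the dict returned by `utils.load_checkpoint`.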

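  2. Train and evaluate in every epoch, keeping track of the best result so far
  3. Save the weights after every epoch, and additionally keep a copy whenever the epoch produces the best result so far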
best_val_acc = 0.0  # track best validation accuracy (or loss) outside of loop

for epoch in range(params.num_epochs):
    # Run one epoch
    logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

    # Train for one epoch (one full pass over the training set)
    train(model, optimizer, loss_fn, train_dataloader, metrics, params)

    # Evaluate for one epoch on validation set
    val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)

    # Get validation accuracy and track best validation accuracy
    val_acc = val_metrics['accuracy']
    is_best = val_acc >= best_val_acc

    # Save weights every epoch; also copy them to best.pth.tar if is_best
    utils.save_checkpoint({'epoch': epoch + 1,
                           'state_dict': model.state_dict(),
                           'optim_dict': optimizer.state_dict()},
                          is_best=is_best,
                          checkpoint=model_dir)

    # Update the best accuracy and record the metrics that produced it
    if is_best:
        logging.info("- Found new best accuracy")
        best_val_acc = val_acc

        # Save best val metrics in a json file in the model directory
        best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
        utils.save_dict_to_json(val_metrics, best_json_path)
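The `is_best` test decides when `best.pth.tar` is overwritten. The sketch below replays that logic on made-up validation numbers. With `>=`, a tie (epoch 3) also replaces the best copy. If you track validation loss instead of accuracy, as the comment on `best_val_acc` allows, the comparison must flip and the initial value must be `float('inf')`, because starting from `0.0` would never register an improvement. The helpers `is_improvement` and `best_epochs` are hypothetical.

```python
def is_improvement(current, best, mode="max"):
    """Hypothetical helper: True if `current` should replace `best`.

    mode="max" for metrics like accuracy (mirrors `val_acc >= best_val_acc`),
    mode="min" for metrics like validation loss.
    """
    return current >= best if mode == "max" else current <= best


def best_epochs(values, mode="max"):
    """Return the 1-based epochs at which best.pth.tar would be (re)written."""
    # Start from the worst possible value so the first epoch always counts
    best = float("-inf") if mode == "max" else float("inf")
    saved = []
    for epoch, value in enumerate(values, start=1):
        if is_improvement(value, best, mode):
            best = value
            saved.append(epoch)
    return saved


val_accs = [0.50, 0.70, 0.70, 0.65]   # made-up accuracies
val_losses = [1.2, 0.9, 1.0, 0.8]     # made-up losses
print(best_epochs(val_accs, "max"))   # [1, 2, 3]
print(best_epochs(val_losses, "min")) # [1, 2, 4]
```

Use `>` (or `<` for loss) instead if you prefer to keep the earliest of several equally good epochs.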

Reference:
https://nusit.nus.edu.sg/services/hpc-newsletter/deep-learning-best-practices-checkpointing-deep-learning-model-training/
