【AI设计模式】05-检查点模式(CheckPoints):如何定期存储模型?

作者:王磊
更多精彩分享,欢迎访问和关注:https://www.zhihu.com/people/wldandan

目录

在前一篇文章《数据处理-Eager模式》中分享了数据处理-Eager模式,那么在模型训练时,有哪些设计模式可以使用呢?在数据库领域,为了防止执行时间较长的存储过程失败重新执行,会将中间的过程状态以检查点的形式持续记录下来,每次失败时不需要重头执行,而是加载最近的检查点,继续执行,避免浪费时间。和存储过程类似,模型的训练时间会更长,如果缺乏一定的可靠性机制,过程中一旦失败,就需要重头开始训练,浪费时间较多。因此,需要实现类似的机制来保证可靠性问题,这种机制被称为检查点(CheckPoints)模式

AI设计模式总览

模式定义

检查点模式(CheckPoints)是指通过周期性(迭代/时间)的保存模型的完整状态,在模型训练失败时,可以从保存的检查点模型继续训练,以避免训练失败时每次都需要从头开始带来的训练时间浪费。检查点模式适用于模型训练时间长、训练需要提前结束、fine-tune等场景,也可以拓展到异常时的断点续训场景。

问题

  1. 训练耗时的网络在训练过程中失败,从头开始训练的代价高:对于层数比较深的神经网络,或者需要大规模训练数据的模型,训练的时间会很长。因为有更多的参数以及更多的数据样本需要处理。比如对于VGG16的网络,cifar-10的数据集,普通的NVIDIA显卡训练需要3-4小时;一旦过程中失败,需要重头开始训练,时间成本高。
  2. 训练时间越长,精度可能不发生变化,或者产生过拟合的现象。这种场景时,提前结束(early stopping)获得中间的模型状态收益会更高。
  3. Fine-Tune时,通常需要最终模型前面的一些模型状态进行基础上进行调优,这样可以更好的针对新数据进行训练,获得更好的泛化性。

解决方案

在每轮训练结束时,都保存当前的模型状态作为检查点,如果下轮训练失败时,可以从这个检查点模型继续训练。和训练完成导出的模型(以神经网络为例,最终的模型包含权重、激活函数以及隐藏层信息)相比,这个中间模型状态需要额外的轮、当前的批量计数等信息,以保证基于这个中间模型继续训练。通常这个中间模型被称为检查点(CheckPoints)。检查点的模型状态中通常不包括学习率,因为训练过程中它可能会动态调整。

如果在每个批量数据训练完,权重更新后都保存检查点,中间模型的数量和占用的空间会非常大。所以实践中通常会在每轮结束后保存检查点,或者保留最近的几个检查点。

案例

AI框架通常都提供了模型训练的检查点保存能力。在MindSpore中,通过训练API提供了ModelCheckPointCheckpointConfig模块来帮助开发者保存模型训练过程中的检查点。MindSpore提供了三种检查点保存策略,包括直接保存、周期保存(迭代次数或者训练时长)、和异常保存(在训练失败的异常情况下保存的策略)。

说明:检查点文件是一个二进制文件,存储了所有训练参数的值;且检查点的实现上采用了Protocol Buffers机制,与开发语言、平台无关,具有良好的可扩展性。

在此,我们重点介绍下如何在MindSpore中周期性保存模型状态、以及在异常情况下保存故障点的模型状态。

周期保存

1)迭代次数方式保存

下面的MindSpore代码片段展示了使用迭代次数配置检查点保存策略,以及在模型训练时通过回调的方式应用保存策略。训练开始后,会每隔1785个step保存一次检查点模型,并最多保留10个中间模型,模型的名称格式为checkpoint_lenet-1_1875.ckpt

 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig 
    
 # 设置模型保存参数,设置模型保存的策略,如本例中设置最多保存10个checkpoints,每隔1875个step保存一次 
    
 config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10) 
    
 # 应用模型保存参数 
    
 ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) 
    
 #通过回调的方式配置在模型训练的过程中 
    
 model.train(epoch_size, ds_train, callbacks=[ckpoint_cb]) 
      

加载CheckPoint可以通过load_checkpointload_param_into_net方法来完成,如下面的代码,通过load_checkpoint方法从保存好的checkpoint中加载网络的参数,再通过load_param_into_net将参数导入到具体的网络实例中,方便后面续训或者评估。

from mindspore import load_checkpoint, load_param_into_net
# 加载已经保存的用于测试的模型
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
# 加载参数到网络中
load_param_into_net(net, param_dict)

完整的代码可以参考[1]中的案例。

2)周期时间方式保存

时间策略提供了按照秒和分钟配置参数,如下面的代码,每隔30秒保存一个CheckPoint文件,每隔3分钟保留一个CheckPoint文件。

from mindspore import CheckpointConfig

# 每隔30秒保存一个CheckPoint文件,每隔3分钟保留一个CheckPoint文件
config_ck = CheckpointConfig(save_checkpoint_seconds=30, keep_checkpoint_per_n_minutes=3)

异常保存

如果模型较大,通常会减少梳理保留的检查点模型,间隔的时间会拉长。如盘古大模型的检查点保存间隔在4-5小时,如果在两个检查点之间失败,那么从上个检查点重新训练的时间损失会比较大。MindSpore在1.7版本扩展了检查点功能,提供断点续训能力,保证在训练异常时触发检查点,保证下次可以从发生故障时的模型状态继续训练,训练时间无损失。引入断点续训功能,只需在策略配置时增加“exception_save=True”的参数即可。

from mindspore import ModelCheckpoint, CheckpointConfig
# 配置断点续训功能开启
config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10, exception_save=True)

总结

检查点(CheckPoints)模式最大的作用在于保证了模型训练的可靠性,同时也可以让开发者更容易的做早停。断点续训能力对于大模型的价值较大,异常状态下续训无时间损失,检查点模式也有利于转移学习时做fine-tune,这也是我们下一个要介绍的模式。

参考资料

[1] MindSpore完整案例:https://mindspore.cn/tutorials/zh-CN/r1.7/beginner/quick_start.html

[2] MindSpore模型保存:https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/advanced/train/save.ipynb

[3] 机器学习设计模式:https://www.oreilly.com/library/vie

说明:严禁转载本文内容,否则视为侵权。 

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值