在深度学习训练过程中,为及时掌握网络模型的训练状态、实时观察网络模型各参数的变化情况和实现训练过程中用户自定义的一些操作,MindSpore提供了回调机制(Callback)来实现上述功能。
Callback回调机制一般用在网络模型训练过程Model.train中,MindSpore的Model会按照Callback列表callbacks顺序执行回调函数,用户可以通过设置不同的回调类来实现在训练过程中或者训练后执行的功能。
Callback介绍
当聊到回调Callback的时候,大部分用户都会觉得很难理解,是不是需要堆栈或者特殊的调度方式,实际上我们简单的理解回调:
假设函数A有一个参数,这个参数是个函数B,当函数A执行完以后执行函数B,那么这个过程就叫回调。
Callback是回调的意思,MindSpore中的回调函数实际上不是一个函数而是一个类,用户可以使用回调机制来观察训练过程中网络内部的状态和相关信息,或在特定时期执行特定动作。
例如监控损失函数Loss、保存模型参数ckpt、动态调整参数lr、提前终止训练任务等。下面我们继续以手写体识别模型为例,介绍常见的内置回调函数和自定义回调函数。
常用的内置回调函数
MindSpore提供Callback能力,支持用户在训练/推理的特定阶段,插入自定义的操作。
ModelCheckpoint
用于保存训练后的网络模型和参数,方便进行再推理或再训练,MindSpore提供了ModelCheckpoint接口,一般与配置保存信息接口CheckpointConfig配合使用。
LossMonitor
用于监控训练或测试过程中的损失函数值Loss变化情况,可设置per_print_times控制打印Loss值的间隔。
训练场景下,LossMonitor监控训练的Loss值;边训练边推理场景下,监控训练的Loss值和推理的Metrics值
TimeMonitor
用于监控训练或测试过程的执行时间。可设置data_size控制打印执行时间的间隔。
自定义回调机制
MindSpore不仅有功能强大的内置回调函数,当用户有自己的特殊需求时,还可以基于Callback基类自定义回调类。
用户可以基于Callback基类,根据自身的需求,实现自定义Callback。Callback基类定义如下所示:
class Callback():
"""Callback base class"""
def on_train_begin(self, run_context):
"""Called once before the network executing."""
def on_train_epoch_begin(self, run_context):
"""Called before each epoch beginning."""
def on_train_epoch_end(self, run_context):
"""Called after each epoch finished."""
def on_train_step_begin(self, run_context):
"""Called before each step beginning."""
def on_train_step_end(self, run_context):
"""Called after each step finished."""
def on_train_end(self, run_context):
"""Called once after network training."""
回调机制可以把训练过程中的重要信息记录下来,通过把一个字典类型变量RunContext.original_args(),传递给Callback对象,使得用户可以在各个自定义的Callback中获取到相关属性,执行自定义操作,也可以自定义其他变量传递给RunContext.original_args()对象。
RunContext.original_args()中的常用属性有:
- epoch_num:训练的epoch的数量
- batch_num:一个epoch中step的数量
- cur_epoch_num:当前的epoch数
- cur_step_num:当前的step数
- loss_fn:损失函数
- optimizer:优化器
- train_network:训练的网络
- train_dataset:训练的数据集
- net_outputs:网络的输出结果
- parallel_mode:并行模式
- list_callback:所有的Callback函数
通过下面两个场景,我们可以增加对自定义Callback回调机制功能的了解。
自定义终止训练
实现在规定时间内终止训练功能。用户可以设定时间阈值,当训练时间达到这个阈值后就终止训练过程。
下面代码中,通过run_context.original_args方法可以获取到cb_params字典,字典里会包含前文描述的主要属性信息。
同时可以对字典内的值进行修改和添加,在begin函数中定义一个init_time对象传递给cb_params字典。每个数据迭代结束step_end之后会进行判断,当训练时间大于设置的时间阈值时,会向run_context传递终止训练的信号,提前终止训练,并打印当前的epoch、step、loss的值。
从上面的打印结果可以看出,当第3个epoch的第4673个step执行完时,运行时间到达了阈值并结束了训练。
自定义阈值保存模型
该回调机制实现当loss小于设定的阈值时,保存网络模型权重ckpt文件。
示例代码如下: