第六篇 FastAI的回调系统

前一篇博文介绍了如何在Fast AI框架下构建学习器。那么接下来就是训练模型。在训练过程中,往往需要在一些特定的阶段做一些特殊的操作,如在每训练一定数目的batch后,对学习速率进行调整;或是每训练一个epoch后,需要记录网络在验证(validation)数据集上的某些输出,如在validation上的loss,或者一些其他想观察的指标(metrics)。这些都是通过Fast AI的回调(Callbacks)系统实现的。本篇博客将对相关的内容进行具体的介绍,并以Fast AI中定义的回调类为例进行说明。

一、Fast AI回调系统

一般而言,网络的训练主要由两层循环构成。外层循环的每一次执行,都要遍历一遍整个训练数据集,即循环单位为epoch;而内层循环的循环单位为batch,即每次执行会对一个batch的数据做相关操作。而在内层循环中,每次执行又由若干步骤组成:网络的前向计算(forward())、损失值的计算(loss())、梯度的计算(backward())、网络参数的更新(opt.step())。图示如下:

图 1. 网络训练流程

如前所述,有时需要在训练流程中的特定阶段做一些特殊操作。如在训练开始时、每个epoch开始时等等。暂且把这些特殊操作称为回调(Callbacks),而这些能够插入回调功能的特定阶段则可称为回调槽(我自己取的名字,不要认真)。

Fast AI中,用于实现具体回调功能的基类为Callback文档链接,代码见fastai.callback.py文件)。该类定义了若干回调函数的接口,对应于整个训练流程的10个回调槽。图示如下:

图 2. 训练流程的回调槽

所有用于实现特定操作的回调类必须派生自Callback,实现一个或若干个回调槽对应的功能函数。此外,Fast AI中另一个和回调功能有关的重要类是CallbackHandler(代码见fastai.callback.py),该类提供了对各个具体回调类的管理功能。这儿所谓的管理功能,是指Leaner对象的训练函数fit()就是通过CallbackHandler对各个Callback进行调用的。如下图所示:

图 3. Learner通过CallbackHandler对象调用各个回调功能

由上可见,CallbackHandler类的相关的成员函数与回调槽功能函数接口的名称是一致的,这样,CallbackHandler类就为学习器Learner对象提供了调用回调功能的统一接口。除此之外,CallbackHandler类还会维持一个**state_dict字典**,该字典会被传给各个Callback的回调槽的功能函数,并使用返回的值更新相应字段(也就意味着一个Callback类的回调槽功能函数的返回值为一个字典或者返回None,参见CallbackHandler._call_and_update()函数的实现)。state_dict会被用于fit()函数中的各种条件的判断与流程控制。

二、Callback示例: RecorderLRFinderOneCycleScheduler

这里所介绍三个示例,均与学习器的训练过程相关。其中Recorder是在创建Learner对象时默认添加的回调类,用于记录训练流程中学习器的状态。而后两者都与学习速率(lrLearning Rate)的调整有关:LRFinder是关于学习速率查找的;OneCycleScheduler是关于训练过程中对lr进行调整以使模型更快收敛的。

这三个回调类均继承自LearnerCallback类。该类主要提供了在需要进行Callback的序列化时需要保存的参数,如保存学习器Learner对象时,需要同时保存与学习器相关的Callback的状态。这主要通过exclude(需要排除的状态名)和not_min(非最小集的状态名)两个参数进行设置。另外,LearnerCallback还会将回调类注册为所关联的Learner对象的属性,属性名为回调类的类名的蛇形形式(snake case),如回调类名称为MyClass,那么在Learner对象中相应的属性名为my_class。不仅如此,LearnerCallback还覆写了__getattr__()函数,使得通过LearnerCallback对象对一些属性进行的访问,会被重定向到其所依附的Learner对象的相应属性上。

1. Recorder

Recorder类定义在fastai.basic_train.py文件中。该类用于记录训练过程中Learner对象的状态,包括epochlossopt(优化器的状态)、metric等。实际上,在训练过程中Fast AI所展示的监控信息(见下图),就是由该类对象记录的。

图 4. 训练过程中所展示的监控信息

该类主要涉及如下回调槽的功能:

  • (1) on_train_begin(): 初始化一些列表,用于存储lossval_losseslrsmomsmetricsnb_batches等参数。
  • (2) on_epoch_begin(): 记录本次epoch开始的时间。
  • (3) on_batch_end(): 将opt.lropt.mom存进对应的列表。
  • (4) on_backward_begin(): 将平滑过的loss存入对应的列表。
  • (5) on_epoch_end(): 记录本轮迭代的batch数目,存储在Validation上的metrics

该类提供了如下功能:

  • (1) 在使用学习速率查找功能时,绘制损失随学习速率变化的曲线。
    plot(skip_start:int=10, skip_end:int=5, suggestion:bool=False, return_fig:bool=None, **kwargs) → Optional[Figure]
    
    其中skip_startskip_end表示要跳过起始和终止的若干次损失值异常的循环。suggestion表示是否推荐lr
  • (2) 绘制trainvalidation过程中的损失曲线
    plot_losses(skip_start:int=0, skip_end:int=0, return_fig:bool=None) → Optional[Figure]
    
  • (3) 绘制lr曲线
    plot_lr(show_moms=False, skip_start:int=0, skip_end:int=0, return_fig:bool=None) → Optional[Figure]
    
    其中show_moms表示同时绘制动量曲线。
  • (4) 绘制性能指标曲线
    plot_metrics(skip_start:int=0, skip_end:int=0, return_fig:bool=None) → Optional[Figure]
    
2. LRFinder类(文档链接

Learner对象中用于搜索学习速率的函数为lr_find(),定义在fastai.train.py文件中。该函数的主要功能是:对网络训练若干个batch,每次迭代时按等比序列更新lr,记录网络输出的损失值。其中lr的默认搜索区间是[1E-7, 10],可通过start_lrend_lr来设置。batch的数目预设为100,可通过num_it参数修改。但当损失值发散时,就会停止训练。发散的判定方法是:以平滑后的损失值为监控变量,如果当前值大于最小值的4倍,即认为损失值发散。(关于学习速率查找的问题可参考这篇博客。)

而该函数的主要功能是通过LRFinder回调类来实现的。LRFinder定义在fastai.callbacks.lr_finder.py文件中,主要涉及三个回调槽的功能:

  • (1) on_train_begin: 主要做一些初始化的工作,如保存网络的初始状态为tmp,设置不进行validataion,这是通过返回字典值{"skip_validate":True}来实现的。
  • (2) on_batch_end: 主要是调整学习速率。该函数会返回两个是否结束查找过程的控制变量:stop_epochstop_traininglr_find()指定了batch的迭代次数,而非epoch数目。因此,在LRFinder初始化学习速率变化的scheduler时,也是使用的iterations.因此,当迭代次数预定值时,即会设置{"stop_epoch": True, "stop_training": True}。另外当本次迭代的loss大于最优loss的四倍时(此处的loss均是平滑后的结果),也会使得训练停止。
  • (3) on_train_end(): 使用一开始存储的网络状态tmp还原网络。这意味着在训练过程中也可使用lr_find()函数,而不会造成网络状态的不连续。
3. OneCycleScheduler类([文档链接](https://docs.fast.ai/callbacks.one_cycle.html)

OneCycleScheduler(代码见fastai.callbacks.one_cycle.py)主要用于Learner对象的fit_one_cycle()函数(该函数定义在fastai.train.py文件中)中,功能是对整个训练流程中的lr按照特定的曲线进行调整。

fit_one_cycle()的定义如下:

fit_one_cycle(learn:Learner, # 学习器
    cyc_len:int, # 一个cycle的epoch数目
    max_lr:Union[Floats,slice]=defaults.lr, # lr的最大值,默认值为0.003
    moms:Tuple[float,float]=(0.95,0.85), # 动量参数
    div_factor:float=25., # 在lr上升阶段,max_lr与起始时的lr的比值
    pct_start:float=0.3, # lr上升阶段的迭代次数的占比(以batch计算)
    final_div:float=None, # 在lr下降阶段,max_lr与终止时的lr的比值
    wd:float=None, callbacks:Optional[CallbackList]=None, tot_epochs:int=None, start_epoch:int=None)->None:

相关参数及lr变化曲线如下图所示。

图 6. One Cycle中lr的变化曲线
其中`OneCycleScheduler`定义在`fastai.callbacks.one_cycle.py`文件中,主要涉及三个回调槽的功能:
  • (1) on_train_begin: 初始化,主要是初始化两个Scheduler的列表,分别对应lrmom。每个列表由两个Scheduler组成,分别对应训练的两个阶段。
  • (2) on_batch_end: 进行lrmom的调整。由于是两阶段的训练,所以OneCycleScheduler会维护一个索引,用于标识训练进行到了哪一步。
  • (3) on_epoch_end: 判断训练是否结束,并返回stop_training标志。

三、其他Callback文档链接

1.ShowGraph

派生自LearnerCallback类,功能是在训练过程中绘制Learner对象的metrics的曲线。使用方法如下:

learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)
2.GradientClipping

派生自LearnerCallback类,功能是在训练过程中对梯度进行截断。使用方法如下:

learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=partial(GradientClipping, clip=0.1))
3.BnFreeze

派生自LearnerCallback类,功能是不更新BN层参数的滑动平均值。使用方法如下:

learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)
4.AccumulateScheduler

派生自LearnerCallback类,功能是经过若干个batch后,再进行网络参数的更新,但要求在计算损失函数时,使用的是累积的方法。另外,Fast AI未对BN层做相应的处理。使用方法:

learn = cnn_learner(data, resnet18, metrics=accuracy, loss_func=CrossEntropyFlat(reduction='sum'), callback_fns=partial(AccumulateScheduler, n_step=16))
一些有用的链接
  • 8
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值