前一篇博文介绍了如何在Fast AI
框架下构建学习器。那么接下来就是训练模型。在训练过程中,往往需要在一些特定的阶段做一些特殊的操作,如在每训练一定数目的batch
后,对学习速率进行调整;或是每训练一个epoch
后,需要记录网络在验证(validation
)数据集上的某些输出,如在validation
上的loss
,或者一些其他想观察的指标(metrics
)。这些都是通过Fast AI
的回调(Callbacks
)系统实现的。本篇博客将对相关的内容进行具体的介绍,并以Fast AI
中定义的回调类为例进行说明。
一、Fast AI
回调系统
一般而言,网络的训练主要由两层循环构成。外层循环的每一次执行,都要遍历一遍整个训练数据集,即循环单位为epoch
;而内层循环的循环单位为batch
,即每次执行会对一个batch
的数据做相关操作。而在内层循环中,每次执行又由若干步骤组成:网络的前向计算(forward()
)、损失值的计算(loss()
)、梯度的计算(backward()
)、网络参数的更新(opt.step()
)。图示如下:
![](https://img-blog.csdnimg.cn/20200122205147710.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3N1cmVkaWVk,size_16,color_FFFFFF,t_70)
如前所述,有时需要在训练流程中的特定阶段做一些特殊操作。如在训练开始时、每个epoch
开始时等等。暂且把这些特殊操作称为回调(Callbacks
),而这些能够插入回调功能的特定阶段则可称为回调槽(我自己取的名字,不要认真)。
在Fast AI
中,用于实现具体回调功能的基类为Callback
(文档链接,代码见fastai.callback.py
文件)。该类定义了若干回调函数的接口,对应于整个训练流程的10
个回调槽。图示如下:
![](https://img-blog.csdnimg.cn/20200122205331705.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3N1cmVkaWVk,size_16,color_FFFFFF,t_70)
所有用于实现特定操作的回调类必须派生自Callback
,实现一个或若干个回调槽对应的功能函数。此外,Fast AI
中另一个和回调功能有关的重要类是CallbackHandler
(代码见fastai.callback.py
),该类提供了对各个具体回调类的管理功能。这儿所谓的管理功能,是指Leaner
对象的训练函数fit()
就是通过CallbackHandler
对各个Callback
进行调用的。如下图所示:
![](https://img-blog.csdnimg.cn/20200122205440262.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3N1cmVkaWVk,size_16,color_FFFFFF,t_70)
由上可见,CallbackHandler
类的相关的成员函数与回调槽功能函数接口的名称是一致的,这样,CallbackHandler
类就为学习器Learner
对象提供了调用回调功能的统一接口。除此之外,CallbackHandler
类还会维持一个**state_dict
字典**,该字典会被传给各个Callback
的回调槽的功能函数,并使用返回的值更新相应字段(也就意味着一个Callback
类的回调槽功能函数的返回值为一个字典或者返回None
,参见CallbackHandler._call_and_update()
函数的实现)。state_dict
会被用于fit()
函数中的各种条件的判断与流程控制。
二、Callback
示例: Recorder
、LRFinder
和OneCycleScheduler
这里所介绍三个示例,均与学习器的训练过程相关。其中Recorder
是在创建Learner
对象时默认添加的回调类,用于记录训练流程中学习器的状态。而后两者都与学习速率(lr
,Learning Rate
)的调整有关:LRFinder
是关于学习速率查找的;OneCycleScheduler
是关于训练过程中对lr
进行调整以使模型更快收敛的。
这三个回调类均继承自LearnerCallback
类。该类主要提供了在需要进行Callback
的序列化时需要保存的参数,如保存学习器Learner
对象时,需要同时保存与学习器相关的Callback
的状态。这主要通过exclude
(需要排除的状态名)和not_min
(非最小集的状态名)两个参数进行设置。另外,LearnerCallback
还会将回调类注册为所关联的Learner
对象的属性,属性名为回调类的类名的蛇形形式(snake case
),如回调类名称为MyClass
,那么在Learner
对象中相应的属性名为my_class
。不仅如此,LearnerCallback
还覆写了__getattr__()
函数,使得通过LearnerCallback
对象对一些属性进行的访问,会被重定向到其所依附的Learner
对象的相应属性上。
1. Recorder
类
Recorder
类定义在fastai.basic_train.py
文件中。该类用于记录训练过程中Learner
对象的状态,包括epoch
、loss
、opt
(优化器的状态)、metric
等。实际上,在训练过程中Fast AI
所展示的监控信息(见下图),就是由该类对象记录的。
该类主要涉及如下回调槽的功能:
- (1)
on_train_begin()
: 初始化一些列表,用于存储loss
、val_losses
、lrs
、moms
、metrics
、nb_batches
等参数。 - (2)
on_epoch_begin()
: 记录本次epoch
开始的时间。 - (3)
on_batch_end()
: 将opt.lr
和opt.mom
存进对应的列表。 - (4)
on_backward_begin()
: 将平滑过的loss
存入对应的列表。 - (5)
on_epoch_end()
: 记录本轮迭代的batch
数目,存储在Validation
上的metrics
。
该类提供了如下功能:
- (1) 在使用学习速率查找功能时,绘制损失随学习速率变化的曲线。
其中plot(skip_start:int=10, skip_end:int=5, suggestion:bool=False, return_fig:bool=None, **kwargs) → Optional[Figure]
skip_start
和skip_end
表示要跳过起始和终止的若干次损失值异常的循环。suggestion
表示是否推荐lr
。 - (2) 绘制
train
和validation
过程中的损失曲线plot_losses(skip_start:int=0, skip_end:int=0, return_fig:bool=None) → Optional[Figure]
- (3) 绘制
lr
曲线
其中plot_lr(show_moms=False, skip_start:int=0, skip_end:int=0, return_fig:bool=None) → Optional[Figure]
show_moms
表示同时绘制动量曲线。 - (4) 绘制性能指标曲线
plot_metrics(skip_start:int=0, skip_end:int=0, return_fig:bool=None) → Optional[Figure]
2. LRFinder
类(文档链接)
Learner
对象中用于搜索学习速率的函数为lr_find()
,定义在fastai.train.py
文件中。该函数的主要功能是:对网络训练若干个batch
,每次迭代时按等比序列更新lr
,记录网络输出的损失值。其中lr
的默认搜索区间是[1E-7, 10]
,可通过start_lr
和end_lr
来设置。batch
的数目预设为100
,可通过num_it
参数修改。但当损失值发散时,就会停止训练。发散的判定方法是:以平滑后的损失值为监控变量,如果当前值大于最小值的4倍,即认为损失值发散。(关于学习速率查找的问题可参考这篇博客。)
而该函数的主要功能是通过LRFinder
回调类来实现的。LRFinder
定义在fastai.callbacks.lr_finder.py
文件中,主要涉及三个回调槽的功能:
- (1)
on_train_begin
: 主要做一些初始化的工作,如保存网络的初始状态为tmp
,设置不进行validataion
,这是通过返回字典值{"skip_validate":True}
来实现的。 - (2)
on_batch_end
: 主要是调整学习速率。该函数会返回两个是否结束查找过程的控制变量:stop_epoch
和stop_training
。lr_find()
指定了batch
的迭代次数,而非epoch
数目。因此,在LRFinder
初始化学习速率变化的scheduler
时,也是使用的iterations
.因此,当迭代次数预定值时,即会设置{"stop_epoch": True, "stop_training": True}
。另外当本次迭代的loss
大于最优loss
的四倍时(此处的loss
均是平滑后的结果),也会使得训练停止。 - (3)
on_train_end()
: 使用一开始存储的网络状态tmp
还原网络。这意味着在训练过程中也可使用lr_find()
函数,而不会造成网络状态的不连续。
3. OneCycleScheduler
类([文档链接](https://docs.fast.ai/callbacks.one_cycle.html)
OneCycleScheduler
(代码见fastai.callbacks.one_cycle.py
)主要用于Learner
对象的fit_one_cycle()
函数(该函数定义在fastai.train.py
文件中)中,功能是对整个训练流程中的lr
按照特定的曲线进行调整。
fit_one_cycle()
的定义如下:
fit_one_cycle(learn:Learner, # 学习器
cyc_len:int, # 一个cycle的epoch数目
max_lr:Union[Floats,slice]=defaults.lr, # lr的最大值,默认值为0.003
moms:Tuple[float,float]=(0.95,0.85), # 动量参数
div_factor:float=25., # 在lr上升阶段,max_lr与起始时的lr的比值
pct_start:float=0.3, # lr上升阶段的迭代次数的占比(以batch计算)
final_div:float=None, # 在lr下降阶段,max_lr与终止时的lr的比值
wd:float=None, callbacks:Optional[CallbackList]=None, tot_epochs:int=None, start_epoch:int=None)->None:
相关参数及lr
变化曲线如下图所示。
![](https://img-blog.csdnimg.cn/20200122205857756.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3N1cmVkaWVk,size_16,color_FFFFFF,t_70)
- (1)
on_train_begin
: 初始化,主要是初始化两个Scheduler
的列表,分别对应lr
和mom
。每个列表由两个Scheduler
组成,分别对应训练的两个阶段。 - (2)
on_batch_end
: 进行lr
和mom
的调整。由于是两阶段的训练,所以OneCycleScheduler
会维护一个索引,用于标识训练进行到了哪一步。 - (3)
on_epoch_end
: 判断训练是否结束,并返回stop_training
标志。
三、其他Callback
(文档链接)
1.ShowGraph
类
派生自LearnerCallback
类,功能是在训练过程中绘制Learner
对象的metrics
的曲线。使用方法如下:
learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)
2.GradientClipping
类
派生自LearnerCallback
类,功能是在训练过程中对梯度进行截断。使用方法如下:
learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=partial(GradientClipping, clip=0.1))
3.BnFreeze
类
派生自LearnerCallback
类,功能是不更新BN
层参数的滑动平均值。使用方法如下:
learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)
4.AccumulateScheduler
类
派生自LearnerCallback
类,功能是经过若干个batch
后,再进行网络参数的更新,但要求在计算损失函数时,使用的是累积的方法。另外,Fast AI
未对BN
层做相应的处理。使用方法:
learn = cnn_learner(data, resnet18, metrics=accuracy, loss_func=CrossEntropyFlat(reduction='sum'), callback_fns=partial(AccumulateScheduler, n_step=16))