训练网络
简介
learner通过lr_find方法找到合适的学习率,通过fit_one_cycle来训练网络,通过to_fp16来转换半精度。
Learner类
方法描述:
Learner(`data`:DataBunch, `model`:Module, `opt_func`:Callable=`'Adam'`, `loss_func`:Callable=`None`, `metrics`:Collection[Callable]=`None`, `true_wd`:bool=`True`, `bn_wd`:bool=`True`, `wd`:Floats=`0.01`, `train_bn`:bool=`True`, `path`:str=`None`, `model_dir`:str=`'models'`, `callback_fns`:Collection[Callable]=`None`, `callbacks`:Collection[Callback]=``, `layer_groups`:ModuleList=`None`)
训练model使用data,通过opt_func使得loss_func最小化。
此类中包含了训练方式、预测方式、学习率的差异学习以及模型的保存等等。
1)fit方法主要设置学习率与权重衰减
fit(`epochs`:int, `lr`:Union[float, Collection[float], slice]=`slice(None, 0.003, None)`, `wd`:Floats=`None`, `callbacks`:Collection[Callback]=`None`)
2)fit_one_cycle是一种呈周期调整学习率的方法,使用如下:
fit_one_cycle(`learn`:Learner, `cyc_len`:int, `max_lr`:Union[float, Collection[float], slice]=`slice(None, 0.003, None)`, `moms`:Point=`(0.95, 0.85)`, `div_factor`:float=`25.0`, `pct_start`:float=`0.3`, `wd`:float=`None`, `callbacks`:Optional[Collection[Callback]]=`None`, `kwargs`)
3)lr_find用来寻找合适的学习率,使用说明如下:
lr_find(`learn`:Learner, `start_lr`:Floats=`1e-07`, `end_lr`:Floats=`10`, `num_it`:int=`100`, `stop_div`:bool=`True`, `kwargs`:Any)
查看模型结果
具体见文档所述方法。
Recorder类
这一类主要是用来记录epoch、loss、opt等,我们可以很轻松的利用这一类中的方法画出损失的变化等图像,如:
learn.recorder.plot_lr(show_moms=True)
底层训练器
具体见文档,是构成learner的底层。