1. 功能描述:
MindSpore训练模型时,实现保存最优模型。
2. 实现保存最优模型功能简介:
在面对复杂网络时,往往需要进行几十甚至几百次的epoch训练。在训练之前,很难掌握在训练到第几个epoch时,模型的精度能达到满足要求的程度,所以经常会采用一边训练的同时,在相隔固定epoch的位置对模型进行精度验证,并保存相应的模型,等训练完毕后,通过查看对应模型精度的变化就能迅速地挑选出相对最优的模型。
流程如下:
1) 定义回调函数EvalCallBack,实现同步进行训练和验证。
2) 定义训练网络并执行。
3) 将不同epoch下的模型精度绘制出折线图并挑选最优模型。
3. 原因分析:
MindSpore在训练模型时,保存最后一个ckpt可能精度不达标。
4. 解决方案:
apply_eval函数,用来验证模型的精度。定义回调函数EvalCallBack:
模型验证
def apply_eval(eval_param):
eval_model = eval_param['model']
eval_ds = eval_param['dataset']
metrics_name = eval_param['metrics_name']
res = eval_model.eval(eval_ds)
return res[metrics_name]
我们自定义一个数据收集的回调类EvalCallBack,用于实现下面两种信息:
4.1 训练过程中,每一个epoch结束之后,训练集的损失值和验证集的模型精度。
4.2 保存精度最高的模型。
class EvalCallBack(Callback):
"""
回调类,获取训练过程中模型的信息
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"