本文讲述MindSpore自定义在train的过程中实时验证的回调函数,继承callback类自定义evalcallback,可以用来设置:每隔几个epoch进行验证,实时输出model指定的metrics评价指标。然后通过实例化对象,将这个回调过程放入model.train()中的callback()中。
简单来讲就是输出对训练集的实时model.eval()测试结果。
class EvalCallBack(Callback):
def __init__(self, model, eval_dataset, epochs_to_eval, per_eval, dataset_sink_mode):
self.model = model
self.eval_dataset = eval_dataset
# epochs_to_eval是一个int数字,代表着:每隔多少个epoch进行一次验证
self.epochs_to_eval = epochs_to_eval
self.per_eval = per_eval
self.dataset_sink_mode = dataset_sink_mode
def epoch_end(self, run_context):
# 获取到现在的epoch数
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
# 如果达到进行验证的epoch数,则进行以下验证操作
if cur_epoch % self.epochs_to_eval == 0:
# 此处model设定的metrics是准确率Accuracy
acc = self.model.eval(self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
self.per_eval["epoch"].append(cur_epoch)
self.per_eval["acc"].append(acc["Accuracy"])
print("------------准确率为: {} ------------".format(acc["Accuracy"]))