# Goal: save the best checkpoint of the *student* model on the validation set.
# Approach: register a Student_CheckpointHook that
# - subclasses mmengine's CheckpointHook, overriding _get_metric_score and
#   _save_best_checkpoint;
# - _get_metric_score extracts the numeric score of the chosen evaluation
#   metric from the `metrics` dict;
# - _save_best_checkpoint uses the metric-key prefix to decide whether this
#   is the student evaluation pass and, if so, updates the keys/values in
#   self.is_better_than accordingly.
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Optional, Dict, Callable
import warnings
from mmengine.registry import HOOKS
from mmengine.hooks import CheckpointHook
from mmengine.dist import is_main_process
@HOOKS.register_module()
class Student_CheckpointHook(CheckpointHook):
    """CheckpointHook variant for distillation training.

    Tracks and saves only the best *student* checkpoint evaluated on the
    validation set. Metric keys are expected to carry a ``student.`` or
    ``teacher.`` prefix (e.g. ``student.mIoU``); teacher evaluation rounds
    are ignored when deciding which checkpoint is "best".
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        if self.save_best is not None:
            # Mirror the parent's best-checkpoint bookkeeping, but keep a
            # separate path (or path dict) for the student model only.
            if len(self.key_indicators) == 1:
                self.best_student_ckpt_path: Optional[str] = None
            else:
                self.best_student_ckpt_path_dict: Dict = dict()

    def _get_metric_score(self, metrics, key_indicator):
        """Resolve ``key_indicator`` from prefixed evaluation results.

        Args:
            metrics (dict): Evaluation results whose keys are prefixed with
                ``student.`` or ``teacher.`` (e.g. ``student.mIoU``).
            key_indicator (str): Bare (unprefixed) metric name to look up.

        Returns:
            tuple: ``(prefix, score)`` where ``prefix`` is ``'student.'`` or
            ``'teacher.'`` (or ``None`` when ``metrics`` is empty) and
            ``score`` is ``None`` when the indicator cannot be resolved.
        """
        # Fix: the original warned on an empty dict but returned a bare
        # None, which broke the caller's tuple unpacking — and the check was
        # unreachable anyway because indexing the first key raised first.
        if not metrics:
            warnings.warn(
                'Since `eval_res` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return None, None
        # All keys produced by one evaluation round share the same prefix,
        # so the first key tells us whether this is the student or the
        # teacher pass.
        first_key = next(iter(metrics))
        prefix = 'student.' if 'student' in first_key else 'teacher.'
        eval_res = OrderedDict(metrics)
        # .get() avoids a KeyError when the requested indicator is absent;
        # the caller treats a None score as "skip this indicator".
        return prefix, eval_res.get(prefix + key_indicator)

    def _save_best_checkpoint(self, runner, metrics) -> None:
        """Save the current best *student* checkpoint and delete the
        outdated one.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics, with keys
                prefixed by ``student.`` or ``teacher.``.
        """
        if not self.save_best:
            return

        if self.by_epoch:
            ckpt_filename = self.filename_tmpl.format(runner.epoch)
            cur_type, cur_time = 'epoch', runner.epoch
        else:
            ckpt_filename = self.filename_tmpl.format(runner.iter)
            cur_type, cur_time = 'iter', runner.iter

        # handle auto in self.key_indicators and self.rules before the loop
        if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])

        # save best logic: compare each configured indicator against the
        # best score recorded in the runner's message hub.
        for key_indicator, rule in zip(self.key_indicators, self.rules):
            prefix, key_score = self._get_metric_score(metrics, key_indicator)
            # Only the student evaluation pass is considered; teacher
            # results (and unresolvable/empty metrics) are skipped.
            if prefix != 'student.':
                continue
            # Tag the indicator with the prefix so all bookkeeping keys
            # match the prefixed metric names.
            key_indicator = prefix + key_indicator
            # Register the comparator under the prefixed key: the parent
            # class only registered it under the bare indicator name.
            if rule is not None:
                self.is_better_than[key_indicator] = self.rule_map[rule]

            if len(self.key_indicators) == 1:
                best_score_key = 'best_score'
                runtime_best_ckpt_key = 'best_ckpt'
                best_student_ckpt_path = self.best_student_ckpt_path
            else:
                best_score_key = f'best_score_{key_indicator}'
                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
                # Fix: use .get() — the prefixed key is absent from the
                # dict until the first student checkpoint has been saved,
                # so direct indexing raised KeyError on the first round.
                best_student_ckpt_path = \
                    self.best_student_ckpt_path_dict.get(key_indicator)

            if best_score_key not in runner.message_hub.runtime_info:
                best_score = self.init_value_map[rule]
            else:
                best_score = runner.message_hub.get_info(best_score_key)

            if key_score is None or not self.is_better_than[key_indicator](
                    key_score, best_score):
                continue

            best_score = key_score
            runner.message_hub.update_info(best_score_key, best_score)

            # Remove the previous best checkpoint (main process only, to
            # avoid every rank racing on the same file).
            if best_student_ckpt_path and \
                    self.file_client.isfile(best_student_ckpt_path) and \
                    is_main_process():
                self.file_client.remove(best_student_ckpt_path)
                runner.logger.info(
                    f'The previous best checkpoint {best_student_ckpt_path} '
                    'is removed')

            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
            if len(self.key_indicators) == 1:
                self.best_student_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
                    self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(runtime_best_ckpt_key,
                                               self.best_student_ckpt_path)
            else:
                self.best_student_ckpt_path_dict[
                    key_indicator] = self.file_client.join_path(  # type: ignore # noqa: E501
                        self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(
                    runtime_best_ckpt_key,
                    self.best_student_ckpt_path_dict[key_indicator])
            runner.save_checkpoint(
                self.out_dir,
                filename=best_ckpt_name,
                file_client_args=self.file_client_args,
                save_optimizer=False,
                save_param_scheduler=False,
                by_epoch=False,
                backend_args=self.backend_args)
            runner.logger.info(
                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
# NOTE: evaluation results use 'student.'/'teacher.'-prefixed keys,
# e.g. 'student.mIoU'.
# NOTE(review): orphaned fragment — this code references `self.runner` and an
# uninitialized `student_metrics`, so it cannot run here at module level. It
# appears to be a pasted excerpt from the method that re-prefixes validation
# metrics before re-dispatching the checkpoint hooks — verify against its
# real enclosing scope before relying on it.
for key, value in metrics.items():
    student_key = 'student.' + key
    teacher_key = 'teacher.' + key
    # collect student-prefixed copies of every metric
    # (presumably `student_metrics = {}` is initialized just above — TODO confirm)
    student_metrics[student_key] = value
    # drop the teacher entry from the logged scalars, if present
    self.runner.message_hub.log_scalars.pop(f'val/{teacher_key}', None)
self.runner.call_hook('after_val_epoch', metrics=student_metrics)  # call the hook here