mmrazor distillation: rewriting CheckpointHook

Goal: save the checkpoint at which the student model achieves its best score on the validation set.

Approach: register a custom Student_CheckpointHook (a config sketch for enabling it follows the implementation below)

  • Inherit from CheckpointHook and override the _get_metric_score and _save_best_checkpoint methods
  • _get_metric_score extracts the score of the chosen evaluation metric from the metrics dict
  • _save_best_checkpoint uses the prefix to tell whether the current evaluation belongs to the student model; if it does, it updates the key/value entries in the self.is_better_than dict
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Dict, Optional
import warnings

from mmengine.registry import HOOKS
from mmengine.hooks import CheckpointHook
from mmengine.dist import is_main_process


@HOOKS.register_module()
class Student_CheckpointHook(CheckpointHook):
    """继承mmengine的CheckpointHook, 用以适应蒸馏场景下,保存验证集上student模型最好的评估指标

    Args:
        CheckpointHook (_type_): _description_
    """
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        
        if self.save_best is not None:
            if len(self.key_indicators) == 1:
                self.best_student_ckpt_path: Optional[str] = None
            else:
                self.best_student_ckpt_path_dict: Dict = dict()

        
    def _get_metric_score(self, metrics, key_indicator):
        """Return the metric prefix ('student.' or 'teacher.') and the score
        of ``key_indicator`` taken from ``metrics``."""
        eval_res = OrderedDict()
        if metrics is not None:
            eval_res.update(metrics)

        if len(eval_res) == 0:
            warnings.warn(
                'Since `eval_res` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return None, None

        # Whether the evaluation belongs to the student or the teacher is
        # identified by the 'student.'/'teacher.' prefix that the validation
        # loop adds to every metric key, so the first key is enough to tell.
        first_key = next(iter(eval_res))
        prefix = 'student.' if 'student' in first_key else 'teacher.'

        return prefix, eval_res[prefix + key_indicator]

    def _save_best_checkpoint(self, runner, metrics) -> None:
        """Save the current checkpoint and delete outdated checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if not self.save_best:
            return

        if self.by_epoch:
            ckpt_filename = self.filename_tmpl.format(runner.epoch)
            cur_type, cur_time = 'epoch', runner.epoch
        else:
            ckpt_filename = self.filename_tmpl.format(runner.iter)
            cur_type, cur_time = 'iter', runner.iter

        # handle auto in self.key_indicators and self.rules before the loop
        if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])

        # save best logic
        # get score from messagehub
        for key_indicator, rule in zip(self.key_indicators, self.rules):
            prefix, key_score = self._get_metric_score(metrics, key_indicator)

            if prefix != 'student.':
                # skip the teacher's evaluation results (the teacher is not
                # saved during training) as well as empty metric dicts
                continue

            else:
                # tag the indicator with its prefix, e.g. 'student.mIoU'
                key_indicator = prefix + key_indicator

                # register the comparison function for the prefixed key in
                # the self.is_better_than dict
                if rule is not None:
                    self.is_better_than[key_indicator] = self.rule_map[rule]

                if len(self.key_indicators) == 1:
                    best_score_key = 'best_score'
                    runtime_best_ckpt_key = 'best_ckpt'
                    best_student_ckpt_path = self.best_student_ckpt_path
                else:
                    best_score_key = f'best_score_{key_indicator}'
                    runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
                    # the parent class tracks self.best_ckpt_path_dict; the
                    # student's best path is tracked separately here
                    best_student_ckpt_path = self.best_student_ckpt_path_dict[key_indicator]

                if best_score_key not in runner.message_hub.runtime_info:
                    best_score = self.init_value_map[rule]
                else:
                    best_score = runner.message_hub.get_info(best_score_key)

                if key_score is None or not self.is_better_than[key_indicator](
                        key_score, best_score):
                    continue

                best_score = key_score
                runner.message_hub.update_info(best_score_key, best_score)

                if best_student_ckpt_path and \
                        self.file_client.isfile(best_student_ckpt_path) and \
                        is_main_process():
                    self.file_client.remove(best_student_ckpt_path)
                    runner.logger.info(
                        f'The previous best checkpoint {best_student_ckpt_path} '
                        'is removed')

                best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
                if len(self.key_indicators) == 1:
                    self.best_student_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
                        self.out_dir, best_ckpt_name)
                    runner.message_hub.update_info(runtime_best_ckpt_key,
                                                self.best_student_ckpt_path)
                else:
                    self.best_student_ckpt_path_dict[
                        key_indicator] = self.file_client.join_path(  # type: ignore # noqa: E501
                            self.out_dir, best_ckpt_name)
                    runner.message_hub.update_info(
                        runtime_best_ckpt_key,
                        self.best_student_ckpt_path_dict[key_indicator])
                runner.save_checkpoint(
                    self.out_dir,
                    filename=best_ckpt_name,
                    file_client_args=self.file_client_args,
                    save_optimizer=False,
                    save_param_scheduler=False,
                    by_epoch=False,
                    backend_args=self.backend_args)
                runner.logger.info(
                    f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                    f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
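
To enable the hook during training, the checkpoint entry of default_hooks can be pointed at Student_CheckpointHook in the config. A minimal sketch, assuming the class above is saved in a module such as projects/hooks/student_checkpoint_hook.py and that mIoU is the tracked indicator (both the import path and the indicator are placeholders for your own setup):

# Hypothetical config fragment: make the custom hook importable so the HOOKS
# registry can resolve 'Student_CheckpointHook', then swap it in for the
# default CheckpointHook.
custom_imports = dict(
    imports=['projects.hooks.student_checkpoint_hook'],  # adjust to your path
    allow_failed_imports=False)

default_hooks = dict(
    checkpoint=dict(
        type='Student_CheckpointHook',
        interval=1,
        by_epoch=True,
        save_best='mIoU',  # indicator name without the 'student.' prefix
        rule='greater'))

Because the hook prepends the prefix itself (key_indicator = prefix + key_indicator), the best checkpoint ends up named like best_student.mIoU_epoch_N.pth and its path is kept in self.best_student_ckpt_path.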
                


  • During evaluation the metric keys carry a student. or teacher. prefix, e.g. student.mIoU. The fragment below, taken from the distillation validation loop, builds the prefixed student metrics that are handed to the hook (see the illustration below).
        # collect the student's results under the 'student.' prefix
        student_metrics = dict()
        for key, value in metrics.items():
            student_key = 'student.' + key
            teacher_key = 'teacher.' + key

            student_metrics[student_key] = value
            # drop the teacher's scalar from the val log of the message hub
            self.runner.message_hub.log_scalars.pop(f'val/{teacher_key}', None)

        # hooks, including Student_CheckpointHook, are called here
        self.runner.call_hook('after_val_epoch', metrics=student_metrics)
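
For a concrete picture of what the hook receives after this loop, here is a small standalone sketch (metric names and values are made up for illustration):

# Hypothetical metric dicts as they might be passed to after_val_epoch.
student_metrics = {'student.mIoU': 76.2, 'student.aAcc': 93.1}
teacher_metrics = {'teacher.mIoU': 78.5, 'teacher.aAcc': 94.0}

hook = Student_CheckpointHook(interval=1, save_best='mIoU', rule='greater')

# ('student.', 76.2): this score is compared against the stored best_score
# and may trigger saving a new best student checkpoint.
print(hook._get_metric_score(student_metrics, 'mIoU'))

# ('teacher.', 78.5): if the hook is ever called with teacher-prefixed
# metrics, _save_best_checkpoint skips them, so the teacher is never saved
# as a "best" checkpoint.
print(hook._get_metric_score(teacher_metrics, 'mIoU'))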