Detectron2 交叉验证。直接计算mAP版本

该博客介绍了如何在Detectron2中实现交叉验证,用于检测模型过拟合。通过自定义ValidationMap类,每间隔一定迭代次数进行验证并记录最佳mAP。在训练过程中,当验证集上的mAP超过阈值时保存模型。尽管最终结果与直接测试可能略有不同,但可有效评估模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Detectron2 交叉验证

目的:在训练的过程中使用验证集合的结果老判断是否出现了过拟合现象。这个教程是在验证集上面直接评定指标,而不是计算验证集合的损失。关于将验证集的损失作为交叉验证的对象可看这里

实现原理

由于Detecton2将不属于模型运行的部分用Hook包装。所以我们可以注册Hook去实现交叉验证方法。关于Detectron2训练的一个简单梳理,可以看这里

参数设置

我们在配置文件中加入以下的字段(也可以直接在代码里面写入数值,但这不方便修改)。

_C.CROSS_VAL = CN()
_C.CROSS_VAL.STATUE = True	# 是否开启交叉验证
_C.CROSS_VAL.ITER = 100		# 多少次个iter后验证一次
_C.CROSS_VAL.LOGAP_THR = 53	# 记录最大mAp的阈值

交叉验证类实现

class ValidationMap(HookBase):
    def __init__(self, cfg, ):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.defrost()                  # enable it to can be changed.
        self.interval = cfg.CROSS_VAL.ITER
        self.data_loader = None
        self.evaluator = None
        self.max_mAP = cfg.CROSS_VAL.LOGAP_THR

    def after_step(self):
        if self.trainer.iter % self.interval == 0 and self.trainer.iter >= self.interval:
            self.trainer.model.eval()
            self.eval_dataset()
            self.trainer.model.train()

    def eval_dataset(self):
        if self.data_loader is None:
            self.data_loader = self.trainer.build_test_loader(self.cfg, self.cfg.DATASETS.TEST[0])
            self.evaluator = self.trainer.build_evaluator(self.cfg, self.cfg.DATASETS.TEST[0])
        result = inference_on_dataset(self.trainer.model,
                                      self.data_loader,
                                      self.evaluator)
        print_csv_format(result)
        for task, res in result.items():
            if isinstance(res, Mapping):
                # Don't print "AP-category" metrics since they are usually not tracked.
                important_res = []
                for k, v in res.items():
                    if "-" not in k:
                        important_res.append((k, v) if v != math.nan else (k, 0))
                self.print_and_log(task, important_res)
                self.save_result_or_not(important_res)

    def print_and_log(self, task, important_res):
        # log data
        if comm.is_main_process():
            self.trainer.storage.put_scalars(**dict(important_res))

    def save_result_or_not(self, important_res):
        if self.max_mAP < dict(important_res)['AP']:
            self.max_mAP = dict(important_res)['AP']
            self.trainer.checkpointer.save('base_model')

使用方法

train.py中导入我们写的hook。注册之后即可使用。

trainer = Trainer(cfg)
trainer.resume_or_load(resume=True)
if cfg.CROSS_VAL.STATUE:
    val_mAP = ValidationMap(cfg)
    trainer.register_hooks([val_mAP])
    trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]

注: 这里和最后通过命令行测出的mAP有点差距(0~1mAP)并且直接使用tainr.test( … )同样也有差距。但相对大小没有改变。希望有大佬能够在评论区指出。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值