Detectron2 交叉验证
目的:在训练的过程中使用验证集合的结果老判断是否出现了过拟合现象。这个教程是在验证集上面直接评定指标,而不是计算验证集合的损失。关于将验证集的损失作为交叉验证的对象可看这里。
实现原理
由于Detecton2将不属于模型运行的部分用Hook包装。所以我们可以注册Hook去实现交叉验证方法。关于Detectron2训练的一个简单梳理,可以看这里。
参数设置
我们在配置文件中加入以下的字段(也可以直接在代码里面写入数值,但这不方便修改)。
_C.CROSS_VAL = CN()
_C.CROSS_VAL.STATUE = True # 是否开启交叉验证
_C.CROSS_VAL.ITER = 100 # 多少次个iter后验证一次
_C.CROSS_VAL.LOGAP_THR = 53 # 记录最大mAp的阈值
交叉验证类实现
class ValidationMap(HookBase):
def __init__(self, cfg, ):
super().__init__()
self.cfg = cfg.clone()
self.cfg.defrost() # enable it to can be changed.
self.interval = cfg.CROSS_VAL.ITER
self.data_loader = None
self.evaluator = None
self.max_mAP = cfg.CROSS_VAL.LOGAP_THR
def after_step(self):
if self.trainer.iter % self.interval == 0 and self.trainer.iter >= self.interval:
self.trainer.model.eval()
self.eval_dataset()
self.trainer.model.train()
def eval_dataset(self):
if self.data_loader is None:
self.data_loader = self.trainer.build_test_loader(self.cfg, self.cfg.DATASETS.TEST[0])
self.evaluator = self.trainer.build_evaluator(self.cfg, self.cfg.DATASETS.TEST[0])
result = inference_on_dataset(self.trainer.model,
self.data_loader,
self.evaluator)
print_csv_format(result)
for task, res in result.items():
if isinstance(res, Mapping):
# Don't print "AP-category" metrics since they are usually not tracked.
important_res = []
for k, v in res.items():
if "-" not in k:
important_res.append((k, v) if v != math.nan else (k, 0))
self.print_and_log(task, important_res)
self.save_result_or_not(important_res)
def print_and_log(self, task, important_res):
# log data
if comm.is_main_process():
self.trainer.storage.put_scalars(**dict(important_res))
def save_result_or_not(self, important_res):
if self.max_mAP < dict(important_res)['AP']:
self.max_mAP = dict(important_res)['AP']
self.trainer.checkpointer.save('base_model')
使用方法
在train.py
中导入我们写的hook。注册之后即可使用。
trainer = Trainer(cfg)
trainer.resume_or_load(resume=True)
if cfg.CROSS_VAL.STATUE:
val_mAP = ValidationMap(cfg)
trainer.register_hooks([val_mAP])
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
注: 这里和最后通过命令行测出的mAP有点差距(0~1mAP)并且直接使用tainr.test( … )同样也有差距。但相对大小没有改变。希望有大佬能够在评论区指出。