detectron2的训练过程大致就是hook+trainer的过程,具体可以查看detectron2/engine/train_loop.py中的源码,HookBase基类和Trainerbase基类
Hook的主要内容就是针对深度学习训练、测试过程中的不同需求,定义了很多个不同的Hook
,用来处理训练之前、之后需要准备、收尾的工作。其中包括了计算时间的IterationTimer
、按一定周期输出结果的PeriodicWriter
、调整学习率的LRScheduler
等。当然我们也可以自己添加hook
Trainerbase的工作就是遍历这些hook,依次执行,不过这个遍历有四种,分别是before_train
,after_train
,before_step
,after_step
。还有一个就是run_step
,这个函数其实就是平常我们在编写训练过程的代码,例如读数据,训练模型,获取损失值,求导数,反向梯度更新等,只不过在这个类里面没有定义。
detectron2定义了两种Trainer,都是继承于Trainerbase
1.SimpleTrainer :提供了最简单的 单损失,单优化器,单数据集的 训练循环,没有其它的任何功能(包括保存,记录等),这些功能可以通过 hook 来实现。
2.DefaultTrainer :在 tools/train_net.py 和很多脚本中被使用。DefaultTrainer 从 config 初始化,包含了一些允许用户自定义的更加标准化的操作,比如优化器的选择、学习率的规划、记录日志、保存模型、评测模型等。
下面直接上demo:
import detectron2.data.transforms as T
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.config import get_cfg
cfg = get_cfg()
config_file = "configs/RCNN/Base-RCNN.yaml"
cfg.merge_from_file(config_file)
cfg.DATASETS.TRAIN = ("smoke_train",)
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.001
cfg.INPUT.MAX_SIZE_TRAIN = 400
cfg.INPUT.MAX_SIZE_TEST = 400
cfg.INPUT.MIN_SIZE_TRAIN = (160,)
cfg.INPUT.MIN_SIZE_TEST = 160
cfg.OUTPUT_DIR = "./test_cocodata_register"
mapper = DatasetMapper(cfg,is_train=True)
train_loader = build_detection_train_loader(cfg,mapper=mapper)
from detectron2.engine import DefaultTrainer
print(cfg)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()