detectron2(目标检测框架)无死角玩转-06:源码详解(2)-Trainer继承关系,Hook

detectron2 同时被 2 个专栏收录
11 篇文章 10 订阅
27 篇文章 1 订阅

以下链接是个人关于detectron2(目标检测框架),所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文 末 附 带 \color{blue}{文末附带} 公 众 号 − \color{blue}{公众号 -} 海 量 资 源 。 \color{blue}{ 海量资源}。

detectron2(目标检测框架)无死角玩转-00:目录

前言

通过前面的博客,已经知道detectron2的整体架构,源码我们再回溯到detectron2/engine/train_loop.py,可以看到TrainerBase,我们来看看他存在那些子孙如下:

# 老祖宗 detectron2/engine/train_loop.py
class TrainerBase: 

#第一代子孙 detectron2/engine/train_loop.py
class SimpleTrainer(TrainerBase): 

# 第二代子孙 detectron2/engine/defaults.py
class DefaultTrainer(SimpleTrainer): 

# 第三代子孙 tools/train_my.py-本人参考源码实现
class Trainer(DefaultTrainer):

我们先从老祖宗看起:

class TrainerBase:

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()
        # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
        self.storage.step()

    def run_step(self):
        raise NotImplementedError

我这里删减了很多注释,大家可以阅读一下源码的英文注释。总的来说,还是很简单的,首先需要实现如下函数:

# 已经实现
def after_train(self):  def after_train(self): def before_step(self):

# 已定义,待子类实现
def run_step(self):
    raise NotImplementedError 

通过源码为我们可以知道,after_train,after_train,before_step他们的实现过程真的很简单,就是循环调用 self._hooks中对应的函数,那么self._hooks是什么东西呢?翻译过来为钩子!不急我们先放一放,其中的实现的

def register_hooks(self, hooks)

也放在后面一起讲解,我们先来看看他的第一代子孙class SimpleTrainer(TrainerBase):,其重写了def run_step(self),实现了

	# 检测异常
    def _detect_anomaly(self, losses, loss_dict):
    
    # 简单的看作日志记录即可
    def _write_metrics(self, metrics_dict: dict):

很明显,核心部分为def run_step(self),重写如下:

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        
        # 确定为训练模式
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        
        """
        # 获取一个batch_size的数据,如果有必要,是可以对dataloader进行装饰的
        If your want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        # 如果有必要,可以重写loss的计算过程
        If your want to do something with the losses, you can wrap the model.
        """
        loss_dict = self.model(data)
        losses = sum(loss for loss in loss_dict.values())
        
        # 检测loss计算是否异常
        self._detect_anomaly(losses, loss_dict)

        # 写入log日志
        metrics_dict = loss_dict
        metrics_dict["data_time"] = data_time
        self._write_metrics(metrics_dict)

        """
        # 进行反向传播
        If you need accumulate gradients or something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()
        losses.backward()

        """
        # 一次迭代完成
        If you need gradient clipping/scaling or other processing, you can
        wrap the optimizer with your custom `step()` method.
        """
        self.optimizer.step()

其实还是很好理解的,一路分析到这里,已经完成了反向传播。我们继续分析,看看其第二代子孙class DefaultTrainer(SimpleTrainer),路径为detectron2/engine/defaults.py,是有点复杂吧,不过关系不大,我们慢慢分析就好,再其初始化函数中我们又看到了

        model = self.build_model(cfg) # 构建模型
        optimizer = self.build_optimizer(cfg, model) # 构建优化方式
        data_loader = self.build_train_loader(cfg) # 构建训练数据迭代器

很熟悉的,DefaultTrainer主要实现可如下函数:

    # 继续训练,或者重新加载模型
    def resume_or_load(self, resume=True):

    # 构建和训练相关的hooks
    def build_hooks(self):
	
	# 主要调用了父类的train
    def train(self):
    
	# 根据cfg构建网络模型
    def build_model(cls, cfg):

	# 构建SGD优化器
    def build_optimizer(cls, cfg, model):
    
	# 定义学习率衰减方式 
    def build_lr_scheduler(cls, cfg, optimizer):

	# 构建训练数据迭代器
	def build_train_loader(cls, cfg):

	# 构建测试数据迭代器
    def build_test_loader(cls, cfg, dataset_name):

	# 用于训练过程中,进行验证,主意,这里为空,并没有实现
    def build_evaluator(cls, cfg, dataset_name):

	# 对数据进行测试
	def test(cls, cfg, model, evaluators=None):

可以看到,第三代子孙的功能基本以及很完善了,也就是剩下

def build_evaluator(cls, cfg, dataset_name):

需要之类重写,除此之外,还有一个重点,那当然就是:

    # 构建和训练相关的hooks
    def build_hooks(self):

我们暂且先放一下,来看看第四代子孙,也就是本人仿写tools/train_my.py中的class Trainer(DefaultTrainer),实现了:

# 根据cfg配置,构建评估器
def build_evaluator(cls, cfg, dataset_name, output_folder=None):

# 这个我就是抄过来的,暂时不知道给来做什么的
def test_with_TTA(cls, cfg, model):

到这里,我们把祖宗到第三代都稍微过了一一遍,现在,还有一个重点,那就是Hook了。

Hook

首先,我们第一次提到Hook,是在祖宗TrainerBase的初始化函数之中:

class TrainerBase:
    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)
    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()
        # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
        self.storage.step()

从这里可以很明确的看到,self._hooks列表中,存在着很多hook,当调用before_train,after_train,before_step其会循环调用self._hooks列表中所有hook对应的函数,def register_hooks(self, hooks),就是把hook注册到self._hooks列表中,我们先来看:

class HookBase:
    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

似乎没有什么好看的,定义了几个函数,但是都没有实现实际上的东西,那么我们在源码中查看一下,其在那些地方被调用了:
在这里插入图片描述
可以看到,在源码中,HookBase的子类还是非常多的,其都是在detectron2/engine/hooks.py中实现:

# 可以自定义回调函数
class CallbackHook(HookBase):

# 对训练过程中的时间进行记录,追踪
class IterationTimer(HookBase):

# 迭代之前和迭代之后周期性写入
class PeriodicWriter(HookBase):

# 周期性的进行检查
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):

# 学习率调整策略,每次迭完之后,都进行,判断是否达到学习率改变条件
class LRScheduler(HookBase):

# 迭代到指定次数,则进行评估
class EvalHook(HookBase):

# 可以简单理解为BN的升级版本
class PreciseBN(HookBase):

这里,为大家做一个简单的介绍,如果后续使用到这些hook再做详细的介绍。其实这些hook是很有的一个点子,大家在做消融实验的时候可以使用到。

总的来说,我们可以创建各种各样的hook,只要该hook继承于HookBase,就能通过TrainerBase.register_hooks进行注册,每个hook可以实现一下几个函数:

def before_train(self):
def after_train(self):
def before_step(self):
def after_step(self):

在这里,我们拿class EvalHook(HookBase)来举一个例子,该类实现是为了对训练中的模型进行测试,一般来说,测试都在迭代一定次数之后,再进行验证,所以其重写了函数:

    def after_step(self):

迭代达到指定次数后就会进行测试,其初始化函数如下:

    def __init__(self, eval_period, eval_function):
        self._period = eval_period
        self._func = eval_function

其传入了两个参数,一个是验证周期,一个是验证(测试)函数。

结语

到这里,对于整体的把控,又更近一步了,下小结我们就来看看数据的预处理过程,也就是训练数据的迭代器。

在这里插入图片描述

  • 10
    点赞
  • 2
    评论
  • 6
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

相关推荐
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值