From Abstraction to Extension: Designing a Flexible Training Framework

The basic training loop

class Trainer:
    def __init__(self):
        # basic hyper-parameters (max_iter, save_iter, log_iter, ...)
        ...
        # the "big three": model, optimizer, data loader
        self.model = create_model()
        self.optimizer = create_optimizer()
        self.data_loader = create_dataloader()

        self.learning_rate_adjuster = create_lr_adjuster()

        self.saver = create_saver()

        self.writer = create_tensorboard_writer()

    def train(self):
        for self.iter in range(0, self.max_iter):
            # the three core training steps
            # step 1: load data
            data = next(self.data_loader)

            # step 2: compute the loss
            loss, acc, other_info = self.model(data)

            # step 3: backward pass and parameter update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # periodic side tasks hard-coded into the loop
            if self.iter % self.save_iter == 0:
                self.saver.save(self.model)
            if self.iter % self.log_iter == 0:
                self.writer.log('loss', loss)
                self.writer.log('acc', acc)
                self.writer.log('other_info', other_info)
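The drawback of this monolithic design is that every new periodic task has to be spliced into train() itself. For example, adding periodic validation would mean yet another hard-coded branch in the loop; a hypothetical sketch (eval_iter, evaluate and val_loader are made-up names, not part of the listing above):

# inside train(), one more branch per task:
if self.iter % self.eval_iter == 0:
    metrics = evaluate(self.model, self.val_loader)  # hypothetical helper
    self.writer.log('val_metrics', metrics)

Saving, logging and evaluation stay entangled with the forward/backward logic, which is exactly what the hook abstraction below untangles.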

Abstracting the training loop

Each "task" is defined as an independent "hook", and each hook implements its own callback methods.

import weakref

from detectron2.utils.events import EventStorage


class HookBase:

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

class TrainerBase:

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter
        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()
        # this guarantees that in each hook's after_step, storage.iter == trainer.iter
        self.storage.step()

    def run_step(self):
        raise NotImplementedError
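The EventStorage context that wraps the loop comes from detectron2 (detectron2.utils.events). Hooks and model code do not log directly; they put scalars into the current storage, and writer hooks flush it periodically. A minimal sketch of that pattern, assuming detectron2 is installed:

from detectron2.utils.events import EventStorage, get_event_storage

with EventStorage(start_iter=0) as storage:
    # anywhere inside the context (e.g. in run_step or a hook),
    # the current storage can be fetched and written to:
    get_event_storage().put_scalar("loss", 0.5)
    storage.step()  # advance storage.iter once per iteration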

Based on hooks, the training code can be abstracted one step further.

class Saver(HookBase):
    def __init__(self, save_iter):
        self.save_iter = save_iter

    def after_step(self):
        if self.trainer.iter % self.save_iter == 0:
            save_model(self.trainer.model)


class Writer(HookBase):
    def __init__(self, write_iter):
        self.write_iter = write_iter
        self.writer = TensorboardWriter(...)

    def after_step(self):
        if self.trainer.iter % self.write_iter == 0:
            # read the debug info that run_step stored on the trainer
            self.writer.write(self.trainer._debug_info['loss'])


class Trainer(TrainerBase):
    def __init__(self, save_iter, write_iter):
        super().__init__()
        self.model = create_model()
        self.optimizer = create_optimizer()
        self.data_loader = create_dataloader()
        self._debug_info = {}
        # TrainerBase.register_hooks sets h.trainer = weakref.proxy(self)
        # on every hook, so hooks can read the trainer's state
        self.register_hooks([Saver(save_iter), Writer(write_iter)])

    def run_step(self):
        # step 1: load data
        data = next(self.data_loader)
        # step 2: compute the loss
        loss, acc, other_info = self.model(data)
        # expose the latest metrics so hooks (e.g. Writer) can log them
        self._debug_info = {'loss': loss, 'acc': acc, 'other_info': other_info}
        # step 3: backward pass and parameter update
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
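With this structure, adding a new periodic task no longer touches run_step: it is just one more hook. For example, a learning-rate scheduling hook could look like the sketch below (scheduler is assumed to be a standard torch.optim.lr_scheduler object created elsewhere):

class LRSchedulerHook(HookBase):
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def after_step(self):
        # step the scheduler once per iteration, independently of the training loop
        self.scheduler.step()

Registering it is one extra entry in the list passed to register_hooks in Trainer.__init__.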

1. When the Trainer is initialized, it registers a series of hooks, and each hook is responsible for exactly one job.
2. While registering the hooks, h.trainer = weakref.proxy(self) attaches the trainer to every hook as an attribute, so that inside a hook you can read training state recorded by the trainer, e.g. via h.trainer.iter, as the small example below shows.
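The weakref.proxy trick deserves a small standalone example (toy classes, not the real trainer): attribute access is forwarded to the real object, so the hook can transparently read the trainer's state, but no strong reference cycle between trainer and hook is created, and the proxy does not keep the trainer alive on its own.

import weakref

class Toy:
    pass

trainer = Toy()
hook = Toy()
hook.trainer = weakref.proxy(trainer)  # forwards attribute access to the real trainer

trainer.iter = 100
print(hook.trainer.iter)               # 100

del trainer
# hook.trainer.iter would now raise ReferenceError:
# the proxy does not keep the trainer alive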

The design philosophy behind detectron2

Indirect function calls and multi-level inheritance

detectron2 pushes the same idea further through a chain of trainers: TrainerBase only defines the hook scaffolding, SimpleTrainer fills in a generic run_step, DefaultTrainer builds everything from a config and registers the default hooks, and task-specific trainers only override the pieces they need.

class SimpleTrainer(TrainerBase):

    def __init__(self, model, data_loader, optimizer):

        super().__init__()

        # Note that, for flexibility, the data_loader, model and optimizer are still
        # not built here; they are passed in from outside. The class that actually
        # builds them is introduced in the next section.
        model.train()
        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer

    def run_step(self):
        # fetch a batch with next() and time how long data loading takes
        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        # forward pass
        loss_dict = self.model(data)
        losses = sum(loss for loss in loss_dict.values())
        self._detect_anomaly(losses, loss_dict)

        metrics_dict = loss_dict
        metrics_dict["data_time"] = data_time
        self._write_metrics(metrics_dict)

        # backward pass and parameter update
        self.optimizer.zero_grad()
        losses.backward()
        self.optimizer.step()

    def _detect_anomaly(self, losses, loss_dict):
        if not torch.isfinite(losses).all():
            raise FloatingPointError(
                "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
                    self.iter, loss_dict
                )
            )

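SimpleTrainer can already be used on its own: you build the model, data loader and optimizer however you like, inject them, and register whatever hooks you need. A sketch, where the build_my_* helpers are hypothetical placeholders:

from detectron2.engine import SimpleTrainer, hooks

model = build_my_model()                  # hypothetical builders
optimizer = build_my_optimizer(model)
data_loader = build_my_loader()

trainer = SimpleTrainer(model, data_loader, optimizer)
trainer.register_hooks([hooks.IterationTimer()])
trainer.train(start_iter=0, max_iter=90000)

DefaultTrainer below is essentially this pattern plus "build everything from a config" defaults.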
class DefaultTrainer(SimpleTrainer):


    def __init__(self, cfg):
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )
        super().__init__(model, data_loader, optimizer)

        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
        ...


    def build_hooks(self): 
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            ...
        ]

        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
  
        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers()))
        return ret

    def build_writers(self):

        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):

        super().train(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

class TextTrainer(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg):
        # overriding this single method is enough to plug in a custom data pipeline
        text_mapper = BoundaryMapper(cfg)
        data_loader = build_detection_train_loader(cfg, mapper=text_mapper)
        return data_loader

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        text_mapper = BoundaryTestMapper(cfg)
        test_data_loader = build_detection_test_loader(cfg, dataset_name, mapper=text_mapper)
        return test_data_loader
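Adapting the whole pipeline to a new task therefore comes down to a small subclass like this plus a short launch script. A sketch of such a script, assuming a detectron2-style config file (the path is a made-up placeholder) and that the dataset has already been registered:

from detectron2.config import get_cfg

def main():
    cfg = get_cfg()
    cfg.merge_from_file("configs/my_text_config.yaml")  # hypothetical config path
    cfg.freeze()
    trainer = TextTrainer(cfg)
    trainer.resume_or_load(resume=False)  # provided by detectron2's DefaultTrainer (not shown in the excerpt above)
    return trainer.train()

if __name__ == "__main__":
    main()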