MMfewshot之runner

runner是用于管理整个(训练+验证)、测试流程的类。

mmfewshot 中的runners根据任务需求,大致划分为EpochBasedRunner IterBasedRunner,两者均继承自BaseRunner

1 BaseRunner

BaseRunner实现了初始化方法,预定义了如下抽象方法,需要在子类中实现:

  • run()
  • train()
  • val()
  • save_checkpoint()

初始化方法接收 modeloptimizerlogger,并检验各个对象的类别,将他们保存为自己的类成员。model可以 为wrapper或模型本身。model必须具备train_setp方法。

BaseRunner实现了很多 hook注册方法,比如register_hook(self)。该方法接收一个hook对象,并根据其 priorty插入到self._hooks中。

提供了很多方法,用以访问内部属性情况,比较重要的有get_hook_info(self),返回runner在不同周期内部调用hook的顺序。

2 EpochBasedRunner

继承自BaseRunner,除开run() train() val() save_checkpoint()外,为满足 resume 的需求,也实现了相关方法。且实现了train_iter()方法,该方法在train()中循环调用,用以处理单个batch

2.1 EpochBasedRunner在框架中的调用

train.py为例。

  • parse args;生成cfg dict类对象;构建任务路径;初始化model;初始化train_dataset,或有val_dataset
  • 调用 train_model API
    train_model(
            model,
            datasets,
            cfg,
            distributed=distributed,
            validate=(not args.no_validate),
            timestamp=timestamp,
            device='cpu' if args.device == 'cpu' else 'cuda',
            meta=meta)
    

within train_model:

  • 构建 train_dataloader
  • 根据cfg,用MMDataParallel/MMDistributedDataParallel wrap moel
  • 构建 optimizer
  • 实例化runner
     runner = build_runner(
             cfg.runner,
             default_args=dict(
                 model=model,
                 batch_processor=None,
                 optimizer=optimizer,
                 work_dir=cfg.work_dir,
                 logger=logger,
                 meta=meta))
    
  • 注册training hooksrunner.register_training_hooks() ,该方法接收 cfg 字典
  • 如果需要validate,build val_datasetval_dataset一般是MetaTestDataset(EpisodicDataset)类,其内部的__getitem__方法根据self._mode有不同的返回值。self._mode有3种设定值,test_setsupportquery。设置为test_setdataset返回单个的数据样本,这个mode是用于遍历所有的测试样本,提取feature并保存,用以加速后续的测试环节。
  • build_meta_test_dataloader(val_dataset, meta_test_cfg)构建validate dataloader。不管在pretrainingmeta training阶段,validate都是采用episodic的方式,这与meta testing的数据pipeline是相同的,不用过分纠结函数命名问题。meta test就是指episodic validate。注意,valdatasetself._mode有三种模式,比起构建一个dataloader,频繁的切换self._mode,直接采用深度拷贝构建3中不同模式的dataloader在验证阶段使用更为方便。一般默认使用episodic的方式进行validate,所以val_datasets都是EpsodicDataset的子类。
  • 实例化 MetaTestEvalHook。并将其注册入 runner._hooksMetaTestEvalHook用于模型验证,在初始化阶段接收上述3个 dataloade,以及validate cfg。MetaTestEvalHook实现了若干 hook方法,较为重要的是after_train_epochafter_train_iter。此外,实现了evaluate_save_ckpt,在model 进行一个 epoch / batchforward后,会分别调用这两个hook。根据 validate cfg 的差异,MetaTestEvalHook 会在 after_train_epochafter_train_iter 方法中选择调用 evaluate 方法,进行模型验证。一般而言,EpochBasedRunner 选择在 after_train_epoch 中进行验证;iterBasedRunner则在 after_train_iter中进行验证。如果验证模型具有更高的精度,则在evaluate中调用_save_ckpt进行保存。pretraining 阶段使用 EpochBasedRunnermeta training阶段使用IterBasedRunner
  • 注册default hooscustom hooksrunner中。
  • runner.resume / runner.load_checkpoint 重启训练或读取CKPT
  • runner.run(data_loaders, cfg.workflow) 开启训练。注意,若cfg.workflow('train',1),此处传入的data_loaders只包含train_dataloader。因为3个val_dataloader已经保存在MetaTestEvalHook的类成员中了,并且注册到runner._hooks

within runnner.run():
开始迭代data_loader:

  • self.call_hook('before_run')
  • while self.epoch < self._max_epochs:
    • 根据传入的cfg.workflow中的 str mode,来判断当前任务的类型(train/val),是执行runner.train还是runner.valepoch_runner = getattr(self, mode)但一般 str mode都置为train, validate在hook中进行。
    • 调用epoch_runner(data_loader)runner.trainrunner.val。一般都只调用runnner.train,所以下文只讨论前者。
  • self.call_hook('after_run')

within runner.train():

  • self.call_hook('before_train_epoch')
  • for i, data_batch in enumerate(self.data_loader):
    • self.call_hook('before_train_iter')
    • self.run_iter(data_batch, train_mode=True, **kwargs)
    • self.call_hook('after_train_iter'), 在IterBasedRunner中, MetaTestEvalHook会于此调用 evaluate()
  • self.call_hook('after_train_epoch'), 在EpochBasedRunner中, MetaTestEvalHook会于此调用 evaluate()

within runner.run_iter(data_batch, train_mode=True):

  • 根据 bool train_mode 判断执行self.model.train_step()/self.model.val_step()
  • 保存model 的输出。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值