def val(self, data_loader, **kwargs):
    """Run one validation epoch over ``data_loader``.

    The model is switched to eval mode and every batch is forwarded with
    gradient tracking disabled. A batch goes through ``self.batch_processor``
    when one is configured, otherwise through ``self.model.val_step``.
    Hooks fire before/after the epoch and before/after every iteration.

    Args:
        data_loader: Iterable of batches to validate on. It is also stored
            on ``self.data_loader`` so hooks can reach it.
        **kwargs: Extra keyword arguments forwarded to ``val_step`` /
            ``batch_processor``.

    Raises:
        TypeError: If the step function returns something other than a dict.
    """

    def _forward(batch):
        # Prefer the user-supplied batch processor; fall back to the model's
        # own validation step.
        if self.batch_processor is not None:
            return self.batch_processor(
                self.model, batch, train_mode=False, **kwargs)
        return self.model.val_step(batch, self.optimizer, **kwargs)

    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition

    for step, batch in enumerate(data_loader):
        self._inner_iter = step
        self.call_hook('before_val_iter')

        with torch.no_grad():
            result = _forward(batch)

        if not isinstance(result, dict):
            raise TypeError('"batch_processor()" or "model.val_step()"'
                            ' must return a dict')

        # Only step outputs that carry logging info feed the log buffer.
        if 'log_vars' in result:
            self.log_buffer.update(result['log_vars'], result['num_samples'])

        self.outputs = result
        self.call_hook('after_val_iter')

    self.call_hook('after_val_epoch')
run:执行训练/验证流程。
输入参数包括 data_loaders 和 workflow,两者长度相同,按位置一一对应。
假如 workflow 是 [('train', 2), ('val', 1)],则表示先训练 2 个 epoch,再验证 1 个 epoch;按照这个顺序依次执行,构成一轮循环,并反复循环,直到 self.epoch 达到 max_epochs。
之后会根据 workflow 中每一项的 mode,通过 getattr 选择对应的 train/val 方法执行。
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    """Start running.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively.
        max_epochs (int): Total training epochs.
        **kwargs: Forwarded to every epoch runner (the runner method named
            by each workflow phase, e.g. ``self.train`` / ``self.val``).

    Raises:
        AssertionError: If ``data_loaders`` is not a list, ``workflow`` is
            not a list of tuples, or their lengths differ.
        TypeError: If a workflow phase name is not a string.
        ValueError: If a workflow phase names a method the runner lacks.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    self._max_epochs = max_epochs
    # The iteration budget is derived from the first 'train' phase only.
    for i, (mode, _) in enumerate(workflow):
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    while self.epoch < max_epochs:
        for i, (mode, epochs) in enumerate(workflow):
            if not isinstance(mode, str):
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))
            if not hasattr(self, mode):
                raise ValueError(
                    f'runner has no method named "{mode}" to run an '
                    'epoch')
            # e.g. mode == 'train' -> self.train
            epoch_runner = getattr(self, mode)

            for _ in range(epochs):
                # Once the training budget is spent, only leave this phase.
                # ``break`` (not ``return``) lets the remaining phases of the
                # round run and guarantees the ``after_run`` hooks below fire.
                if mode == 'train' and self.epoch >= max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')