训练基本流程
class Trainer():
    """Basic training loop: build the components, then iterate load → forward → backward."""

    def __init__(self):  # fixed: the original signature was missing `self`
        ...
        # basic hyper-parameters (max_iter, save_iter, log_iter, ...)
        ...
        # the three core components, plus scheduler / saver / logger
        self.model = create_model()
        self.optimizer = create_optimizer()
        self.data_loader = create_dataloader()
        self.learning_rate_adjuster = create_lr_adjuster()
        self.saver = create_saver()
        self.writer = create_tensorboard_writer()

    def train(self):
        """Run the three-step loop for max_iter iterations."""
        for self.iter in range(0, self.max_iter):
            # step 1: data loading
            data = next(self.data_loader)
            # step 2: loss computation (forward pass)
            loss, acc, other_info = self.model(data)
            # step 3: back-propagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # number of completed iterations; the original kept a second,
            # redundant `iteration` counter alongside self.iter
            iteration = self.iter + 1
            if iteration % self.save_iter == 0:
                # fixed: the original called self.saver.save(model) with an
                # undefined global `model`
                self.saver.save(self.model)
            if iteration % self.log_iter == 0:
                self.writer.log('loss', loss)
                self.writer.log('acc', acc)
                self.writer.log('other_info', other_info)
抽象训练流程
每一个“任务”都定义为独立的“hook”，每一个 hook 都会实现自己的方法。
class HookBase:
    """Base class for trainer hooks.

    A hook implements any subset of the four callbacks below; each default
    implementation is a no-op, so subclasses override only what they need.
    """

    def before_train(self):
        """Called once, before the first iteration."""

    def after_train(self):
        """Called once, after the last iteration."""

    def before_step(self):
        """Called before every iteration."""

    def after_step(self):
        """Called after every iteration."""
class TrainerBase:
    """Hook-driven training-loop skeleton.

    Owns the list of registered hooks and fires their callbacks around a
    `run_step` that concrete subclasses must provide.
    """

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        """Attach hooks; each one gets a weak back-reference to this trainer."""
        valid = [h for h in hooks if h is not None]
        for hook in valid:
            assert isinstance(hook, HookBase)
            # a weakref proxy avoids a reference cycle between trainer and hook
            hook.trainer = weakref.proxy(self)
        self._hooks.extend(valid)

    def train(self, start_iter: int, max_iter: int):
        """Run the loop for iterations [start_iter, max_iter)."""
        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter
        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            finally:
                # always fire after_train, even if a step raised
                self.after_train()

    def before_train(self):
        for hook in self._hooks:
            hook.before_train()

    def after_train(self):
        for hook in self._hooks:
            hook.after_train()

    def before_step(self):
        for hook in self._hooks:
            hook.before_step()

    def after_step(self):
        for hook in self._hooks:
            hook.after_step()
        # advance storage here so that inside every hook's after_step,
        # storage.iter == trainer.iter
        self.storage.step()

    def run_step(self):
        raise NotImplementedError
基于hook对训练代码进行进一步的抽象。
class Saver(HookBase):
    """Hook that periodically checkpoints the trainer's model."""

    def __init__(self, save_iter):
        # checkpoint period, in iterations
        self.save_iter = save_iter

    def after_step(self):
        due = self.trainer.iter % self.save_iter == 0
        if due:
            save_model(self.trainer.model)
class Writer(HookBase):
    """Hook that writes per-step debug metrics (e.g. the loss) to TensorBoard."""

    def __init__(self, write_iter):
        # scratch dict, expected to be repopulated each step by whoever
        # produces the metrics
        self._debug_info = {}
        # logging period, in iterations
        self.write_iter = write_iter
        self.writer = TensorboardWriter(...)

    def before_step(self):
        # clear stale metrics so a step can never log the previous step's values
        self._debug_info = {}

    def after_step(self):
        # fixed: the original indexed self._debug_info['loss'] directly and
        # raised KeyError whenever no loss was recorded during the step
        loss = self._debug_info.get('loss')
        if loss is not None:
            self.writer.write(loss)
class Trainer(TrainerBase):
    """Concrete trainer: wires the Saver/Writer hooks onto the base loop."""

    def __init__(self, save_iter=1000, write_iter=20):
        # fixed: the original never called super().__init__() (so self._hooks
        # was missing) and referenced undefined globals save_iter / write_iter;
        # they are now parameters with defaults
        super().__init__()
        # the inherited register_hooks already filters out None, type-checks
        # each hook and sets h.trainer = weakref.proxy(self); the original
        # re-implemented it incorrectly (iterated an undefined name `hooks`
        # and returned None)
        self.register_hooks([Saver(save_iter), Writer(write_iter)])

    def run_step(self):
        # note: the base class's `for self.iter in range(...)` loop owns the
        # iteration counter; the original also incremented self.iter here,
        # which effectively skipped every other iteration.
        # step 1: data loading
        data = next(self.data_loader)
        # step 2: loss computation (forward pass)
        loss, acc, other_info = self.model(data)
        # step 3: back-propagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
1、Trainer在初始化时注册一系列的hooks , 每个hook 可以完成一个工作
2、注册 hooks 的时候，通过 h.trainer = weakref.proxy(self) 把自身变为 hooks 的属性，使得 hook 中可以通过 h.trainer.iter 获取 trainer 内部记录的一些训练状态相关的信息
detectron2 设计思想
间接调用函数与多级继承
class SimpleTrainer(TrainerBase):
    """Minimal single-model trainer: forward, sum losses, backward, step.

    model / data_loader / optimizer are injected rather than built here, so
    this class stays agnostic of how they are constructed (the builder classes
    are introduced separately).
    """

    def __init__(self, model, data_loader, optimizer):
        super().__init__()
        model.train()
        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer

    def run_step(self):
        # fixed: `start` was never defined in the original, so computing
        # data_time crashed with NameError
        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start
        # forward pass: the model is expected to return a dict of named losses
        loss_dict = self.model(data)
        losses = sum(loss for loss in loss_dict.values())
        self._detect_anomaly(losses, loss_dict)
        # copy so that adding timing info does not mutate the model's own dict
        metrics_dict = dict(loss_dict)
        metrics_dict["data_time"] = data_time
        self._write_metrics(metrics_dict)
        # backward pass
        self.optimizer.zero_grad()
        losses.backward()
        self.optimizer.step()

    def _detect_anomaly(self, losses, loss_dict):
        """Raise FloatingPointError if the total loss became NaN or infinite."""
        if not torch.isfinite(losses).all():
            raise FloatingPointError(
                "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
                    self.iter, loss_dict
                )
            )
class DefaultTrainer(SimpleTrainer):
    """Config-driven trainer.

    Builds model / optimizer / data loader from a cfg object, wraps the model
    with DDP when running distributed, and registers a default set of hooks.
    """

    def __init__(self, cfg):
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)
        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )
        super().__init__(model, data_loader, optimizer)
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg
        self.register_hooks(self.build_hooks())

    ...

    def build_hooks(self):
        """Assemble the default hook list: timer, LR scheduler, checkpointer, writers."""
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        hook_list = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            ...,
        ]
        if comm.is_main_process():
            hook_list.append(
                hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)
            )
        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            hook_list.append(hooks.PeriodicWriter(self.build_writers()))
        return hook_list

    def build_writers(self):
        """Default metric writers: console, JSON file, and TensorBoard."""
        output_dir = self.cfg.OUTPUT_DIR
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(output_dir, "metrics.json")),
            TensorboardXWriter(output_dir),
        ]

    def train(self):
        """Run the full loop, then verify evaluation results on the main process."""
        super().train(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
        return self._last_eval_results
class TextTrainer(DefaultTrainer):
    """Text-detection trainer: only the data loaders differ from DefaultTrainer."""

    def build_train_loader(self, cfg):
        # overriding this single method is enough to swap in the text dataset
        mapper = BoundaryMapper(cfg)
        return build_detection_train_loader(cfg, mapper)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        mapper = BoundaryTestMapper(cfg)
        return build_detection_test_loader(cfg, dataset_name, mapper=mapper)