一. 任务描述
完整的训练流程有多个环节, 每个环节都有自己要关注的任务, 列举如下:
- 训练. 关注的有 {batch_size 控制, 分布式相关, 记 log, ckpt 恢复续跑, warmup} 等.
- 评估. 关注的有 {评估时机, batch_size 控制, 记 log, loss 之外的 metric 计算} 等.
- 预测. 关注的有 {如何嵌入在评估环节, batch_size 控制, 记 log} 等.
- 保存. 关注的有 {保存时机, early stop, last n 个检查点维护} 等.
工作较为琐碎. 我用 train loop 统一代指, 因为主流程在 while epoch < total_epoch , for step,batch_data in enumerate(dataset)
内, 是一个 loop.
transformers 是一个流行的预训练模型库, 我想学习并复用它的 train loop 设计.
3600 行的 Trainer.py
适配了大量的场景, 有各种 if 判断, 所以显得很啰嗦. 比如
- 设备: gpu, tpu
- 数据并行: dp, ddp
- 模型并行: deepspeed支持
- 加速特性: 混合精度
二. TrainingArguments
用于指定不同环节的工作参数, 与 Trainer 搭配使用.
实例化
直接构造个 TrainingArguments 对象更直观.
# debug_flag 可用于cpu开发环境或生产gpu单机环境的调试
debug_flag = True
TrainingArguments(remove_unused_columns=False,
seed=42,
learning_rate=5e-5,
label_smoothing_factor=0,
per_device_train_batch_size=1000 if not debug_flag else 4,
num_train_epochs=1,
log_level='info',
logging_dir=flags.output_dir + 'logdir',
logging_strategy=IntervalStrategy.STEPS,
logging_steps=500 if not debug_flag else 50,
evaluation_strategy=IntervalStrategy.STEPS,
per_device_eval_batch_size=2000 if not debug_flag else 10,
# 这里是 train_loop 中的 eval 间隔
eval_steps=20000 if not debug_flag else 60,
label_names=['label'],
output_dir=flags.output_dir + 'model_save_dir',
save_steps=100000 if not debug_flag else 500,
save_total_limit=3,
# optim='xx'
)
一些参数解读如下:
- log_level, 只接受 str 类型, 控制 trainer logger 的日志级别.
- evaluation_strategy, 有三种取值
- “no”, 默认取值, 训练期间不 eval
- “steps”, 按照另一个参数 “eval_steps” 指定的步数间隔作评估
- “epoch”, 每个 epoch 结束做一次评估
源码解读
class TrainingArguments:
def _setup_devices():
torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
if device.type == "cuda":
torch.cuda.set_device(device)
@property
def train_batch_size(self) -> int:
return per_device_batch_size * max(1, self.n_gpu)
Q1: per_device_batch_size 与 train_batch_size 的关系是怎样的?
A: 见源码中的 train_batch_size() 方法, 应该只有DP模式(单机多卡, 单进程多线程)下会出现 train_batch_size ≠ per_device_batch_size 的情况.
三. Trainer
数据集这里只感知 dataset, data_collator 这两个对象, 化繁为简, 不再需要 dataloader.
实例化
eval_dataset_dict = {"算auc的dataset": eval_dataset}
trainer: Trainer = Trainer(
model=my_model,
args=training_args, # training arguments, defined above
train_dataset=train_dataset,
eval_dataset=eval_dataset_dict,
data_collator=data_collator,
compute_metrics=my_compute_metrics_fn
)
源码解读
class Trainer:
def __init__(model,
args: TrainingArguments,
data_collator,
train_dataset: Union[Dataset, IterableDataset],
eval_dataset,
compute_loss_func: Callable,
compute_metrics: Callable[[EvalPrediction], Dict],
callbacks: List[TrainerCallback],
optimizers
):
self._move_model_to_device(model, args.device)
def get_train_dataloader(self) -> DataLoader:
# 若 TrainingArguments.remove_unused_columns 为 False, 则直接使用
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
if isinstance(train_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
# 这里作分片
train_dataset = IterableDatasetShard(...)
return DataLoader(...)
return DataLoader(
train_dataset,
batch_size=self._train_batch_size,
sampler=train_sampler,
collate_fn=data_collator,...)
def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
self._train_batch_size = batch_size
train_dataloader = self.get_train_dataloader()
model = self._wrap_model(self.model_wrapped)
logger.info("***** Running training *****")
if resume_from_checkpoint:
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
for epoch in range(epochs_trained, num_train_epochs):
epoch_iterator = train_dataloader
for step, inputs in enumerate(epoch_iterator):
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
tr_loss_step = self.training_step(model, inputs)
self.current_flos += float(self.floating_point_ops(inputs))
if (step + 1) % args.gradient_accumulation_steps == 0 :
self.optimizer.step()
model.zero_grad()
self.state.global_step += 1
# 这里会更新 step 的 tqdm 进度条
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
return TrainOutput(self.state.global_step, train_loss, metrics)
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
model.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
loss.backward()
return loss.detach()
# 这个函数可以友好重写
def compute_loss(self, model, inputs, return_outputs=False):
loss = self.label_smoother(outputs, labels)
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_evaluate:
if isinstance(self.eval_dataset, dict):
# 不同的 dataset 计算不同的 metric
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
)
else:
# 注意这里不再传入 eval_dataset
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
evaluate
class Trainer:
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
"""A dictionary containing the evaluation loss and the potential metrics """
eval_dataloader = self.get_eval_dataloader(eval_dataset)
eval_loop = self.evaluation_loop
output:EvalLoopOutput = eval_loop(
eval_dataloader,
description="Evaluation",
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
self.log(output.metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics
def evaluation_loop(self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
batch_size = self.args.eval_batch_size
logger.info(f"***** Running {description} *****")
logger.info(f" Num examples = {self.num_examples(dataloader)}")
logger.info(f" Batch size = {batch_size}")
model.eval()
for step, inputs in enumerate(dataloader):
# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
# Update containers on host
if loss is not None:
# loss.shape is [], losses .shape is [batch_size*N]
losses = self._nested_gather(loss.repeat(batch_size))
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if labels is not None:
同上
if inputs_decode is not None:
同上
if logits is not None:
同上
# Metrics!
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
if args.include_inputs_for_metrics:
metrics = self.compute_metrics(
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
)
return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
predict
用户代码可以这么写:
# 控制内部的 has_label 参数为 False, 这样避开计算 loss, 只预测
training_args.label_names = []
trainer = MyTrainer(model=model, args=training_args,
data_collator=data_collator
)
# 从指定目录中找最新的 ckpt
resume_from_checkpoint = trainer_utils.get_last_checkpoint(training_args.output_dir)
logger.info(f"resume_from_checkpoint={resume_from_checkpoint}")
# 虽然是受保护的方法, 但还是要调用
# 内部实现是: model 用内存中的, tensor 值用 ckpt 中的 state_dict.
trainer._load_from_checkpoint(resume_from_checkpoint)
trainer.predict(test_dataset=eval_dataset)
再看源码, 依旧用的是 evaluation_loop .
class Trainer:
def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
test_dataloader = self.get_test_dataloader(test_dataset)
eval_loop = self.evaluation_loop
output = eval_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
total_batch_size = self.args.eval_batch_size * self.args.world_size
self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
has_labels = False if len(self.label_names) == 0 else True
inputs = self._prepare_inputs(inputs)
if has_labels: pass
else:
labels = None
with torch.no_grad():
if has_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
else:
loss = None
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
if prediction_loss_only:
return (loss, None, None)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
return (loss, logits, labels)
四. Metric
对于分类任务来讲, 只有 eval_loss 是不够直观的, 还要有 auc, gauc 等指标. 这时就需要自己写计算代码.
compute_metrics
Trainer.__init__
(compute_metrics) 参数传入自己写的方法.
接口约定是这样的:
- compute_metrics_fn(eval_prediction:EvalPrediction)->Dict[str, float]
有时需要计算分组聚合的指标, 如 gauc, 需要传入 user_id, 就要用到 inputs, 例子见下:
def compute_metrics_gauc(eval_prediction: trainer_utils.EvalPrediction):
# 是 ndarray 类型
preds, labels = eval_prediction.predictions, eval_prediction.label_ids
inputs: List[str] = eval_prediction.inputs
""" 此处省略计算逻辑 """
return {
'predict_mean': preds.mean(),
'label_mean': labels.mean(),
'auc': auc,
'gauc': gauc
}
EvalPrediction 与 EvalLoopOutput 源码
class EvalPrediction:
def __init__(
self,
predictions: Union[np.ndarray, Tuple[np.ndarray]],
label_ids: Union[np.ndarray, Tuple[np.ndarray]],
inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
):
pass
def __iter__(self):
return iter((self.predictions, self.label_ids))
def __getitem__(self, idx):
if idx == 0:
return self.predictions
elif idx == 1:
return self.label_ids
elif idx == 2:
return self.inputs
class EvalLoopOutput(NamedTuple):
predictions: Union[np.ndarray, Tuple[np.ndarray]]
label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
metrics: Optional[Dict[str, float]]
num_samples: Optional[int]
五. TrainerCallback
TrainerCallback 是一个接口, 位于 trainer_callback.py, 规范了不同时机的回调方法, 用于触发特定的行为.
该设计模式可减少 train loop 中的代码侵入, 又不失行为扩展的灵活性.
- self.control=trainer.callback_handler(trainerState, trainerControl )
- self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
class TrainerCallback:
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): pass
def on_train_begin()
def on_train_end()
def on_step_begin()
def on_step_end()
# substep 用于 gradient accumulation 场景
def on_substep_end()
def on_evaluate()
def on_predict()
def on_save()
# 对应 TrainerControl.should_log
def on_log()
辅助类
TrainerControl 与 TrainerState
@dataclass
class TrainerControl:
"""A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
switches in the training loop"""
should_training_stop: bool = False
should_epoch_stop: bool = False
should_save: bool = False
should_evaluate: bool = False
should_log: bool = False
@dataclass
class TrainerState:
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
trainer.init() 中关于 callback 的逻辑见下, 用到了 DefaultFlowCallback 与 CallbackHandler 两个辅助类:
DEFAULT_CALLBACKS = [DefaultFlowCallback]
class Trainer:
def __init__(..., callbacks):
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
CallbackHandler
是 TrainerCallback 的特殊子类, 用于 编排 callbacks_list 中的各个对象.
class CallbackHandler(TrainerCallback):
def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler):
self.callbacks = []
for cb in callbacks:
self.add_callback(cb)
def call_event(self, event, args, state, control, **kwargs):
for callback in self.callbacks:
result = getattr(callback, event)(...)
return control
DefaultFlowCallback
无论初始化时是否传入 callbacks, DefaultFlowCallback 都会作为默认 callback 被加入 (见上文 __init__ 中 DEFAULT_CALLBACKS 的拼接逻辑: callbacks 为 None 时只用默认列表, 否则在默认列表之后追加用户 callbacks).
handles the default flow of the training loop for logs, evaluation and checkpoints.
业务类
ProgressCallback
用于 training or evaluation 的进度条更新.
class ProgressCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar.update(state.global_step - self.current_step)
self.current_step = state.global_step
TensorBoardCallback
库中提供的实现, 来记 tensorboard. 缺点是 只支持 scalar.
- log_dir, 依照 TrainingArguments 中的相应参数
- 记录内容,
- 多进程特点: 只会在满足
state.is_world_process_zero
的主进程中作记录.
class TensorBoardCallback(TrainerCallback):
def __init__(self, tb_writer=None):
has_tensorboard = is_tensorboard_available()
if has_tensorboard:
from torch.utils.tensorboard import SummaryWriter
self._SummaryWriter = SummaryWriter
self.tb_writer = tb_writer
def _init_summary_writer(self, args, log_dir=None):
log_dir = log_dir or args.logging_dir
if self._SummaryWriter is not None:
self.tb_writer = self._SummaryWriter(log_dir=log_dir)
@override
def on_train_begin(self, args, state, control, **kwargs):
if not state.is_world_process_zero:
return
if self.tb_writer is None:
self._init_summary_writer(args, log_dir)
@override
def on_log(self, args, state, control, logs=None, **kwargs):
if not state.is_world_process_zero:
return
if self.tb_writer is not None:
for k, v in logs.items():
self.tb_writer.add_scalar(k, v, state.global_step)
self.tb_writer.flush()
@override
def on_train_end(self, args, state, control, **kwargs):
if self.tb_writer:
self.tb_writer.close()
self.tb_writer = None
EarlyStoppingCallback
todo
自定义callback
todo.
六. 分布式相关
自带 DP 与 DDP 支持. 逻辑在 _inner_training_loop() 内的 model = self._wrap_model(self.model_wrapped)
中 , 源码见下.
class Trainer:
def _wrap_model(self, model, training=True, dataloader=None):
# Multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1:
model = nn.DataParallel(model)
if self.sharded_ddp is not None: pass
elif self.fsdp is not None: pass
elif is_sagemaker_dp_enabled(): pass
elif self.args.local_rank != -1:
kwargs = {}
if self.args.ddp_find_unused_parameters is not None:
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
else:
kwargs["find_unused_parameters"] = True
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
**kwargs,
)
注意 DDP 中, find_unused_parameters=True
的含义是: 每次 forward 结束后检测未参与计算的参数, 避免这些参数阻塞梯度同步 (跨设备对梯度求平均是 DDP 本身的默认行为, 与该参数无关). 实测遇到下方日志, 说明模型并没有未使用的参数, 不用管; 也可以将其设为 False 以省去每步的检测开销.
[W reducer.cpp:1251] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass.
参考
- Trainer 官方doc, todo