The torch train loop design in the transformers library

This post walks through the training pipeline of the transformers library: the key parameters and operations of training, evaluation, prediction and saving. The TrainingArguments class configures the working parameters, and the Trainer class runs the training, with support for distributed training on GPU and TPU as well as acceleration features such as mixed precision. It also covers custom loss computation, custom evaluation metrics, and the callback design (TrainerCallback) that adapts the loop to different tasks.


1. Task description

A complete training pipeline has several stages, and each stage has its own concerns:

  • Training: batch_size control, distributed training, logging, resuming from a checkpoint, warmup, etc.
  • Evaluation: when to evaluate, batch_size control, logging, metrics beyond the loss, etc.
  • Prediction: how it is embedded into the evaluation stage, batch_size control, logging, etc.
  • Saving: when to save, early stopping, keeping only the last n checkpoints, etc.

The work is fairly fiddly. I use "train loop" as an umbrella term, because the main flow lives inside while epoch < total_epoch: for step, batch_data in enumerate(dataset) — one big loop.
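
For concreteness, here is a minimal toy version of the kind of loop meant here (plain PyTorch, not transformers code; every name and number is illustrative), with the concerns above marked as hook points:

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)

total_epochs, eval_steps = 2, 4
for epoch in range(total_epochs):
    for step, (x, y) in enumerate(loader):
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # hook points: logging, evaluation, checkpointing, early stopping, warmup, ...
        if (step + 1) % eval_steps == 0:
            print(f"epoch={epoch} step={step} loss={loss.item():.4f}")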

transformers is a popular pretrained-model library, and I want to study and reuse its train loop design.

The 3,600-line trainer.py

It accommodates a large number of scenarios with all kinds of if branches, which makes it look verbose. For example:

  • Devices: GPU, TPU
  • Data parallelism: DP, DDP
  • Model parallelism: DeepSpeed support
  • Acceleration features: mixed precision

2. TrainingArguments

Specifies the working parameters of the different stages; used together with Trainer.

Instantiation

Constructing a TrainingArguments object directly is the most intuitive way.

# debug_flag is handy for debugging on a CPU dev box or a single-GPU production machine
debug_flag = True
training_args = TrainingArguments(remove_unused_columns=False,
      seed=42,
      learning_rate=5e-5,
      label_smoothing_factor=0,
      per_device_train_batch_size=1000 if not debug_flag else 4,
      num_train_epochs=1,
      log_level='info',

      logging_dir=flags.output_dir + 'logdir',
      logging_strategy=IntervalStrategy.STEPS,
      logging_steps=500 if not debug_flag else 50,

      evaluation_strategy=IntervalStrategy.STEPS,
      per_device_eval_batch_size=2000 if not debug_flag else 10,
      # eval interval inside the train loop
      eval_steps=20000 if not debug_flag else 60,
      label_names=['label'],

      output_dir=flags.output_dir + 'model_save_dir',
      save_steps=100000 if not debug_flag else 500,
      save_total_limit=3,
      # optim='xx'
      )

Some of the parameters, explained:

  • log_level only accepts a str and controls the log level of the trainer logger.
  • evaluation_strategy has three possible values:
    • "no", the default: no evaluation during training
    • "steps": evaluate every eval_steps steps
    • "epoch": evaluate once at the end of each epoch

Reading the source

class TrainingArguments:
    def _setup_devices(self):
        # (excerpt) initializes torch.distributed and pins this process to its device
        torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
        if device.type == "cuda":
            torch.cuda.set_device(device)

    @property
    def train_batch_size(self) -> int:
        return self.per_device_train_batch_size * max(1, self.n_gpu)

Q1: How does per_device_train_batch_size relate to train_batch_size?
A: See the train_batch_size property above. train_batch_size ≠ per_device_train_batch_size should only happen in DP mode (single machine, multiple GPUs, one process with multiple threads), where n_gpu > 1.
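
A quick illustration of that relationship (a sketch; the numbers and output_dir are illustrative):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_out", per_device_train_batch_size=8)
# DP on one machine with 4 visible GPUs: n_gpu == 4, so args.train_batch_size == 8 * 4 == 32
# (this is the batch size handed to the DataLoader).
# DDP: one process per GPU, n_gpu == 1 in each process, so args.train_batch_size == 8.
print(args.train_batch_size)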

3. Trainer

On the data side, Trainer only needs to know about two objects, dataset and data_collator, which keeps things simple — you no longer deal with a DataLoader yourself.
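
A data_collator is just a callable that turns a list of per-example feature dicts into one batched dict of tensors. A minimal sketch, assuming tokenized text inputs; DataCollatorWithPadding is the library-provided collator for variable-length sequences, and my_collator is a hypothetical hand-rolled alternative for fixed-length features:

import torch
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# library collator: pads each batch to its longest sequence
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# hand-rolled collator for fixed-length numeric features
def my_collator(features):
    return {k: torch.tensor([f[k] for f in features]) for k in features[0]}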

Instantiation

eval_dataset_dict = {"auc_dataset": eval_dataset}  # the key ends up in the metric prefix, e.g. eval_auc_dataset_*
trainer: Trainer = Trainer(
    model=my_model,
    args=training_args,  # training arguments, defined above
    train_dataset=train_dataset,
    eval_dataset=eval_dataset_dict,
    data_collator=data_collator,
    compute_metrics=my_compute_metrics_fn
)
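
Training is then kicked off with train(). A minimal sketch of the calls that typically follow the instantiation above:

# resume_from_checkpoint may be None (fresh run) or a checkpoint directory under output_dir
train_output = trainer.train(resume_from_checkpoint=None)
print(train_output.metrics)   # e.g. train_loss, train_runtime
trainer.save_model()          # writes the final model to args.output_dir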

Reading the source

class Trainer:
    def __init__(self, model,
        args: TrainingArguments,
        data_collator,
        train_dataset: Union[Dataset, IterableDataset],
        eval_dataset,
        compute_loss_func: Callable,
        compute_metrics: Callable[[EvalPrediction], Dict],
        callbacks: List[TrainerCallback],
        optimizers,
    ):
        self._move_model_to_device(model, args.device)

    def get_train_dataloader(self) -> DataLoader:
        # if TrainingArguments.remove_unused_columns is False, the collator is used as-is
        data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                # shard the iterable dataset across processes here
                train_dataset = IterableDatasetShard(...)
            return DataLoader(...)
        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator, ...)

    def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
        self._train_batch_size = batch_size
        train_dataloader = self.get_train_dataloader()
        model = self._wrap_model(self.model_wrapped)
        logger.info("***** Running training *****")
        if resume_from_checkpoint:
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_iterator = train_dataloader
            for step, inputs in enumerate(epoch_iterator):
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                tr_loss_step = self.training_step(model, inputs)
                self.current_flos += float(self.floating_point_ops(inputs))
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    self.optimizer.step()
                    model.zero_grad()
                    self.state.global_step += 1
                    # the per-step tqdm progress bar is updated here
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        return TrainOutput(self.state.global_step, train_loss, metrics)


    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
            loss.backward()

        return loss.detach()

    # this method is written to be overridden (a subclass sketch follows this listing)
    def compute_loss(self, model, inputs, return_outputs=False):
        loss = self.label_smoother(outputs, labels)

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_evaluate:
            if isinstance(self.eval_dataset, dict):
                # each dataset gets its own metric prefix
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
            else:
                # note: no eval_dataset is passed here (falls back to self.eval_dataset)
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
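
As flagged in the listing, compute_loss is the intended extension point for a custom loss. A minimal subclass sketch (the class name, the "label" key and the pos_weight value are illustrative; it assumes a model that returns logits):

import torch
from torch import nn
from transformers import Trainer

class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("label")                 # key must match TrainingArguments.label_names
        outputs = model(**inputs)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]
        # weighted BCE as an example of a task-specific loss
        loss_fct = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0))
        loss = loss_fct(logits.squeeze(-1), labels.float())
        return (loss, outputs) if return_outputs else loss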

evaluate

class Trainer:

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Returns a dictionary containing the evaluation loss and the potential metrics."""
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_loop = self.evaluation_loop
        output: EvalLoopOutput = eval_loop(
            eval_dataloader,
            description="Evaluation",
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )
        self.log(output.metrics)
        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
        return output.metrics


    def evaluation_loop(self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        batch_size = self.args.eval_batch_size
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        logger.info(f"  Batch size = {batch_size}")
        model.eval()

        for step, inputs in enumerate(dataloader):
            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            # Update containers on host
            if loss is not None:
                # loss.shape is [], losses.shape is [batch_size * N]
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                ...  # gathered and concatenated the same way as losses
            if inputs_decode is not None:
                ...  # same as above
            if logits is not None:
                ...  # same as above

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

predict

User code can look like this:

# Force has_labels inside prediction_step to be False, so loss is skipped and we only predict
training_args.label_names = []
trainer = MyTrainer(model=model, args=training_args,
                    data_collator=data_collator
                    )
# find the latest checkpoint in the given directory
resume_from_checkpoint = trainer_utils.get_last_checkpoint(training_args.output_dir)
logger.info(f"resume_from_checkpoint={resume_from_checkpoint}")
# _load_from_checkpoint is a protected method, but we call it anyway.
# Internally: the model object is the in-memory one; the tensor values come from the checkpoint's state_dict.
trainer._load_from_checkpoint(resume_from_checkpoint)
trainer.predict(test_dataset=eval_dataset)

Looking at the source again, prediction still goes through evaluation_loop.

class Trainer:
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        test_dataloader = self.get_test_dataloader(test_dataset)
        eval_loop = self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)


    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        has_labels = False if len(self.label_names) == 0 else True
        inputs = self._prepare_inputs(inputs)
        if has_labels:
            pass
        else:
            labels = None
        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

4. Metric

For a classification task, eval_loss alone is not intuitive enough; you also want metrics such as auc and gauc, which means writing the metric computation yourself.

compute_metrics

Pass your own function via the compute_metrics argument of Trainer.__init__().
The contract is:

  • compute_metrics_fn(eval_prediction: EvalPrediction) -> Dict[str, float]

Sometimes group-aggregated metrics such as gauc are needed; they require a user_id, which means using inputs. Example below:

def compute_metrics_gauc(eval_prediction: trainer_utils.EvalPrediction):
    # predictions and label_ids are ndarrays
    preds, labels = eval_prediction.predictions, eval_prediction.label_ids
    inputs: List[str] = eval_prediction.inputs
    """ computation omitted """
    return {
        'predict_mean': preds.mean(),
        'label_mean': labels.mean(),
        'auc': auc,
        'gauc': gauc
    }
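
Note that eval_prediction.inputs is only populated when the corresponding switch in TrainingArguments is on; in the version read here that flag is include_inputs_for_metrics (as seen in evaluation_loop above):

training_args = TrainingArguments(
    output_dir='model_save_dir',
    include_inputs_for_metrics=True,   # evaluation_loop then feeds inputs into EvalPrediction
)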

EvalPrediction and EvalLoopOutput source

class EvalPrediction:
    def __init__(
        self,
        predictions: Union[np.ndarray, Tuple[np.ndarray]],
        label_ids: Union[np.ndarray, Tuple[np.ndarray]],
        inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
    ):
        pass

    def __iter__(self):
        return iter((self.predictions, self.label_ids))

    def __getitem__(self, idx):
        if idx == 0:
            return self.predictions
        elif idx == 1:
            return self.label_ids
        elif idx == 2:
            return self.inputs


class EvalLoopOutput(NamedTuple):
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]

5. TrainerCallback

TrainerCallback is an interface defined in trainer_callback.py. It standardizes the callback hooks fired at different points in time, used to trigger specific behavior.
This design keeps custom code from being woven into the train loop itself, while preserving the flexibility to extend its behavior.

  • inside the loop, hooks are invoked as self.control = self.callback_handler.on_xxx(args, self.state, self.control)
  • the returned TrainerControl then drives self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

class TrainerCallback:
    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): pass
    def on_train_begin()
    def on_train_end()
    def on_step_begin()
    def on_step_end()
    # substep is for the gradient-accumulation case
    def on_substep_end()
    def on_evaluate()
    def on_predict()
    def on_save()
    # corresponds to TrainerControl.should_log
    def on_log()

Helper classes

TrainerControl and TrainerState

@dataclass
class TrainerControl:
    """A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`]
    to activate some switches in the training loop."""

    should_training_stop: bool = False
    should_epoch_stop: bool = False
    should_save: bool = False
    should_evaluate: bool = False
    should_log: bool = False

@dataclass
class TrainerState:
    epoch: Optional[float] = None
    global_step: int = 0
    max_steps: int = 0

The callback-related logic in Trainer.__init__() is shown below; it relies on two helper classes, DefaultFlowCallback and CallbackHandler:

DEFAULT_CALLBACKS = [DefaultFlowCallback]
class Trainer:
    def __init__(..., callbacks):
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

CallbackHandler

A special subclass of TrainerCallback that orchestrates the objects in the callback list.

class CallbackHandler(TrainerCallback):
    def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler):
        self.callbacks = []
        for cb in callbacks:
            self.add_callback(cb)
    def call_event(self, event, args, state, control, **kwargs):
        for callback in self.callbacks:
            result = getattr(callback, event)(...)
        return control
            

DefaultFlowCallback

It is installed by default via DEFAULT_CALLBACKS (as the __init__ excerpt above shows, it is present whether or not you pass your own callbacks) and handles the default flow of the training loop for logs, evaluation and checkpoints.

Concrete callbacks

ProgressCallback

Updates the progress bar during training or evaluation.

class ProgressCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.is_local_process_zero:
            self.training_bar.update(state.global_step - self.current_step)
            self.current_step = state.global_step

TensorBoardCallback

The library-provided implementation for writing to TensorBoard. Its drawback is that it only supports scalars.

  • log_dir: taken from the corresponding TrainingArguments parameter (logging_dir)
  • what gets logged: the logs dict passed to on_log, written as scalars
  • multi-process behavior: only the main process (state.is_world_process_zero) writes anything

class TensorBoardCallback(TrainerCallback):
    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        if has_tensorboard:
            from torch.utils.tensorboard import SummaryWriter
            self._SummaryWriter = SummaryWriter
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    @override
    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

    @override
    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return
        if self.tb_writer is not None:
            for k, v in logs.items():
                self.tb_writer.add_scalar(k, v, state.global_step)
            self.tb_writer.flush()

    @override
    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None

EarlyStoppingCallback

Stops training when the metric chosen via metric_for_best_model has not improved for a given number of evaluations.
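
A minimal usage sketch (it assumes TrainingArguments also sets load_best_model_at_end=True, a metric_for_best_model, and matching evaluation/save strategies; the dataset names are the ones from the instantiation example above):

from transformers import EarlyStoppingCallback, Trainer

trainer = Trainer(
    model=my_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # stop if the monitored metric has not improved for 3 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)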

Custom callbacks

Subclass TrainerCallback, override only the hooks you need, and register the instance either through the callbacks argument of Trainer or via trainer.add_callback().
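
A minimal sketch of a custom callback (the class name and the loss threshold are illustrative):

from transformers import TrainerCallback

class LossSpikeCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        # flip a TrainerControl switch when the logged training loss explodes
        if logs is not None and logs.get("loss", 0.0) > 10.0:
            control.should_training_stop = True
        return control

trainer.add_callback(LossSpikeCallback())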

6. Distributed training

DP and DDP support comes built in. The logic lives in model = self._wrap_model(self.model_wrapped) inside _inner_training_loop(); source below.

class Trainer:
    def _wrap_model(self, model, training=True, dataloader=None):
        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)
        if self.sharded_ddp is not None: pass
        elif self.fsdp is not None: pass
        elif is_sagemaker_dp_enabled(): pass
        elif self.args.local_rank != -1:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            else:
                kwargs["find_unused_parameters"] = True

            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

Note: in DDP, averaging gradients across processes is what DistributedDataParallel does by design; find_unused_parameters=True only tells the reducer to tolerate parameters that receive no gradient in a forward pass, at the cost of an extra graph traversal. If you hit the warning below in practice, it is safe to ignore.

[W reducer.cpp:1251] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass.
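
If every parameter is known to receive a gradient, the extra traversal (and this warning) can be avoided by setting the flag explicitly in TrainingArguments, which feeds the DDP kwargs shown above:

training_args = TrainingArguments(
    output_dir='model_save_dir',
    ddp_find_unused_parameters=False,
)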

References

  1. Trainer official documentation, todo