在seq2seq
库中定义一个专门的训练类Seq2SeqPeftTrainer
和一个评估指标计算类ComputeMetrics
。这些组件主要用于处理序列到序列(Seq2Seq)的生成任务,例如机器翻译、文本摘要等。下面是对这些组件的编写思路、编写目的和作用的详细分析:
编写思路
-
继承基础训练器:
Seq2SeqPeftTrainer
继承自PeftTrainer
,以利用其参数效率的训练功能,并添加特定于序列到序列任务的功能。 -
自定义指标计算:通过
ComputeMetrics
类,封装了生成任务的常用评估指标计算,如BLEU和ROUGE,使用jieba分词来处理中文文本。 -
定制化预测步骤和结果保存:在
Seq2SeqPeftTrainer
中,重写prediction_step
方法以适应生成任务的需要,例如在生成后移除输入提示部分的tokens。还提供了一个save_predictions
方法来保存模型的预测结果。
编写目的
-
增强序列到序列任务的训练和评估能力:通过集成专门的评估指标和训练方法,增强模型在各种生成任务中的表现。
-
提供详细的生成评估指标:通过计算BLEU和ROUGE等指标,详细评估模型生成文本的质量,支持中文环境的特定处理。
-
优化模型输出的保存和分析:使研究者和开发者可以方便地分析和比较模型预测的输出,有助于模型的迭代和优化。
类定义
-
ComputeMetrics
:-
提供一个标准化的方法来计算重要的语言生成指标,帮助评价模型在具体任务上的表现。
-
对于中文文本,通过结合jieba分词和ROUGE评估,提供了更适合中文语境的性能评估。
@dataclass class ComputeMetrics: r""" Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. 用于计算生成任务的评估指标(如 ROUGE 和 BLEU) """ tokenizer: PreTrainedTokenizer def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: r""" Uses the model predictions to compute metrics. """ preds, labels = eval_preds score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) for pred, label in zip(decoded_preds, decoded_labels): hypothesis = list(jieba.cut(pred)) reference = list(jieba.cut(label)) if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} else: rouge = Rouge() scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) result = scores[0] for k, v in result.items(): score_dict[k].append(round(v["f"] * 100, 4)) bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) score_dict["bleu-4"].append(round(bleu_score * 100, 4)) return {k: float(np.mean(v)) for k, v in score_dict.items()}
-
-
Seq2SeqPeftTrainer
:-
专为序列到序列模型设计,支持在训练和预测时处理特定于任务的需求,如生成文本的后处理。
-
提供了保存预测结果的功能,使得结果分析和模型比较更加方便,支持将结果直接用于论文撰写或进一步的数据分析。
class Seq2SeqPeftTrainer(PeftTrainer): r""" Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. 用于计算生成任务的评估指标 """ def prediction_step( self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: r""" Removes the prompt part in the generated tokens. Subclass and override to inject custom behavior. """ input_ids = inputs["input_ids"] loss, generated_tokens, labels = super().prediction_step( model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys ) generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None return (loss, generated_tokens, labels) def save_predictions( self, predict_results: PredictionOutput ) -> None: r""" Saves model predictions to `output_dir`. A custom behavior that not contained in Seq2SeqTrainer. """ if not self.is_world_process_zero(): return output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") logger.info(f"Saving prediction results to {output_prediction_file}") preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True) with open(output_prediction_file, "w", encoding="utf-8") as writer: res: List[str] = [] for pred, label in zip(decoded_preds, decoded_labels): res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res))
-
seq2seq
库通过提供专门针对序列到序列任务的训练器和评估工具,极大地增强了这类模型在文本生成任务中的应用能力和评估效率。这对需要高质量文本生成输出的应用场景提供支持,如自动文摘、机器翻译等。