Seq2SeqPeftTrainer:处理序列到序列(Seq2Seq)的生成任务

seq2seq库中定义一个专门的训练类Seq2SeqPeftTrainer和一个评估指标计算类ComputeMetrics。这些组件主要用于处理序列到序列(Seq2Seq)的生成任务,例如机器翻译、文本摘要等。下面是对这些组件的编写思路、编写目的和作用的详细分析:

编写思路

  1. 继承基础训练器Seq2SeqPeftTrainer继承自PeftTrainer,以利用其参数效率的训练功能,并添加特定于序列到序列任务的功能。

  2. 自定义指标计算:通过ComputeMetrics类,封装了生成任务的常用评估指标计算,如BLEU和ROUGE,使用jieba分词来处理中文文本。

  3. 定制化预测步骤和结果保存:在Seq2SeqPeftTrainer中,重写prediction_step方法以适应生成任务的需要,例如在生成后移除输入提示部分的tokens。还提供了一个save_predictions方法来保存模型的预测结果。

编写目的

  • 增强序列到序列任务的训练和评估能力:通过集成专门的评估指标和训练方法,增强模型在各种生成任务中的表现。

  • 提供详细的生成评估指标:通过计算BLEU和ROUGE等指标,详细评估模型生成文本的质量,支持中文环境的特定处理。

  • 优化模型输出的保存和分析:使研究者和开发者可以方便地分析和比较模型预测的输出,有助于模型的迭代和优化。

类定义

  • ComputeMetrics

    • 提供一个标准化的方法来计算重要的语言生成指标,帮助评价模型在具体任务上的表现。

    • 对于中文文本,通过结合jieba分词和ROUGE评估,提供了更适合中文语境的性能评估。

    @dataclass
    class ComputeMetrics:
        r"""
        Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
        用于计算生成任务的评估指标(如 ROUGE 和 BLEU)
        """
    ​
        tokenizer: PreTrainedTokenizer
    ​
        def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
            r"""
            Uses the model predictions to compute metrics.
            """
            preds, labels = eval_preds
            score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
    ​
            preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
            labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
    ​
            decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
            decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
    ​
            for pred, label in zip(decoded_preds, decoded_labels):
                hypothesis = list(jieba.cut(pred))
                reference = list(jieba.cut(label))
    ​
                if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
                    result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
                else:
                    rouge = Rouge()
                    scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
                    result = scores[0]
    ​
                for k, v in result.items():
                    score_dict[k].append(round(v["f"] * 100, 4))
    ​
                bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
                score_dict["bleu-4"].append(round(bleu_score * 100, 4))
    ​
            return {k: float(np.mean(v)) for k, v in score_dict.items()}

  • Seq2SeqPeftTrainer

    • 专为序列到序列模型设计,支持在训练和预测时处理特定于任务的需求,如生成文本的后处理。

    • 提供了保存预测结果的功能,使得结果分析和模型比较更加方便,支持将结果直接用于论文撰写或进一步的数据分析。

    class Seq2SeqPeftTrainer(PeftTrainer):
        r"""
        Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
        用于计算生成任务的评估指标
        """
    ​
        def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
        ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            r"""
            Removes the prompt part in the generated tokens.
    ​
            Subclass and override to inject custom behavior.
            """
            input_ids = inputs["input_ids"]
            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )
            generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None
            return (loss, generated_tokens, labels)
    ​
        def save_predictions(
                self,
                predict_results: PredictionOutput
        ) -> None:
            r"""
            Saves model predictions to `output_dir`.
    ​
            A custom behavior that not contained in Seq2SeqTrainer.
            """
            if not self.is_world_process_zero():
                return
    ​
            output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
            logger.info(f"Saving prediction results to {output_prediction_file}")
    ​
            preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
            labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
    ​
            decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    ​
            with open(output_prediction_file, "w", encoding="utf-8") as writer:
                res: List[str] = []
                for pred, label in zip(decoded_preds, decoded_labels):
                    res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
                writer.write("\n".join(res))

seq2seq库通过提供专门针对序列到序列任务的训练器和评估工具,极大地增强了这类模型在文本生成任务中的应用能力和评估效率。这对需要高质量文本生成输出的应用场景提供支持,如自动文摘、机器翻译等。

  • 22
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值