自定义peft训练器

peft_trainer 库定义了一个自定义的基于 transformers 库中的 Seq2SeqTrainer的训练器 PeftTrainer,。此外还包含 LogCallback 用于跟踪训练进度和记录关键信息。下面是对这些组件的编写思路、编写目的和作用的详细分析:

编写思路

  1. 继承基础训练器PeftTrainer 继承自 Seq2SeqTrainer,利用其提供的基础训练逻辑,同时扩展以支持参数效率训练(如使用 LoRA)和其他自定义训练行为。

  2. 增强日志记录功能LogCallback 类通过 TrainerCallback 捕捉训练过程中的关键事件,如损失、学习率和预计剩余时间等,并将这些信息保存到日志文件中,以便进行动态可视化和监控。

  3. 模型保存和加载定制:在 PeftTrainer 中重写 _save_load_best_model 方法,以支持复杂的模型保存和加载需求,尤其是处理带有 LoRA 适配器或其他特定微调策略的模型。

类定义

  • LogCallback

    • LogCallback 类继承自 TrainerCallback,用于在训练过程中记录日志。它包含一个 on_log 方法,当训练记录日志时被调用,将关键的训练参数(如步数、损失、学习率等)写入到文件 "trainer_log.jsonl" 中。这个文件用于动态可视化训练进度和结果。

    • 动态记录训练过程中的关键指标,并将它们保存到指定文件中,支持后续的分析和可视化。

    • 提供关于训练进度的实时反馈,包括当前步数、总步数、当前损失、学习率等。

        

class LogCallback(TrainerCallback):
    r"""
    TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
    The on_log function primarily collects process parameters during training, such as training loss, learning rate,
    and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
    time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
    purposes.
    """

    def __init__(self):
        self.start_time = time.time()

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if "loss" not in state.log_history[-1]:
            return
        cur_time = time.time()
        cur_steps = state.log_history[-1].get("step")
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_steps = state.max_steps - cur_steps
        remaining_time = remaining_steps * avg_time_per_step
        log_dict = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": state.log_history[-1].get("loss", None),
            "reward": state.log_history[-1].get("reward", None),
            "learning_rate": state.log_history[-1].get("learning_rate", None),
            "epoch": state.log_history[-1].get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
        }
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
            f.write(json.dumps(log_dict) + "\n")
  • PeftTrainer

    • 代码:

      • __init__: 初始化方法,接受 finetuning_args 和其他关键参数。它在初始化时会检查是否需要删除先前存在的训练日志文件。

      • _save: 保存模型检查点的方法。根据微调类型(如LoRA微调、冻结微调或全参数微调),它保存模型的不同部分:基础模型、值头(如果有的话)、tokenizer等。保存的模型文件按照约定存储在指定的输出目录中,并且将训练参数和微调参数分别保存到文件中。

      • _load_best_model: 加载最佳模型检查点的方法。根据微调类型,它加载相应的可训练参数,并根据需要设置值头(如果有的话)。这个方法也会记录日志,指示从哪个检查点加载了最佳模型。

    • 支持在使用 LoRA 或其他参数效率技术进行微调时,对模型的保存和加载进行特殊处理。

    • 允许在多种微调配置下灵活运行,包括全参数微调和参数冻结策略。

    • 在训练结束后,能够从最佳检查点中加载模型,确保使用验证过程中性能最好的模型版本。

组件作用

  • 支持高级微调技术PeftTrainer 使训练过程能够利用参数效率技术(如 LoRA),这对于处理大模型尤为重要,可以在不显著增加计算成本的情况下提高模型性能。

  • 提供详细的训练监控:通过 LogCallback,实时跟踪和记录训练过程中的各种指标,帮助后续更好地优化训练过程。

  • 定制化模型保存和加载:确保在特定的训练设置下,如使用特定微调策略时,模型的状态可以正确保存并在需要时恢复。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值