TRL里面GRPOTrainer中grpo_train.py文件详解

下面是一篇面向中文读者的博客,旨在全面介绍 grpo_train.py 的整体结构和关键流程,帮助你快速理解它在做什么、如何运行,以及在 GRPO(Group Relative Policy Optimization)训练流程中扮演什么角色。
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py

在这里插入图片描述


1. 背景:GRPO 是什么?

GRPO(Group Relative Policy Optimization) 是一种针对大语言模型(LLM)的强化学习训练方法,可看作 PPO 的一种改进变体。

  • PPO 中通常需要一个与策略模型相当规模的价值函数(critic),但在大型模型上这一点开销很大;
  • GRPO 则通过一次性对同一个 prompt 采样多条回答(一个 “Group”)并进行组内比较,不再需要显式的价值函数;
  • 大幅降低显存和算力需求,也使流程更直观:在同一 prompt 上把多条回答的 reward 进行对比、求相对优势(advantage),直接更新策略模型。

具体可参考笔者的另一篇博客:深度解析DeepSeek原论文中的 GRPO:带 clip 操作的完整公式与示例代码
grpo_train.py 正是针对这个算法的一个 PyTorch / Transformers 生态中的示例式实现,使你能够用“策略模型 + 参考模型 +(可选的)奖励模型”的方式来训练大语言模型。


2. 文件概览

文件中最核心的类是 GRPOTrainer(Trainer),继承自 transformers.Trainer。它重写或扩展了若干方法,包括:

  1. __init__:初始化模型、参考模型(ref_model)、奖励模型(reward_funcs)等,并作一些超参数设置(如 num_generations, beta 等)。
  2. _prepare_inputs:在训练循环中,每一个 batch 先对 prompt 进行采样生成多条回答,调用奖励模型打分,计算组内相对优势。
  3. compute_loss:根据 GRPO 公式,结合 KL 惩罚项和相对优势,计算最终损失并进行反向传播。
  4. prediction_step:训练 / 验证阶段如何调用 _prepare_inputs 并获取 loss。
  5. logcreate_model_card:日志与模型卡部分,可上传到 Hugging Face Hub 做模型管理。

3. 运行逻辑全流程

在实际运行训练时,GRPOTrainer.train() 会在循环中反复调用“采样 + 计算损失 + 更新模型”这条主线逻辑。

3.1 初始化(__init__

  • 加载策略模型model
    用户可以传入一个字符串(表示从 huggingface.co 加载某个 Causal LM),或一个已经初始化好的模型对象。
  • 加载参考模型ref_model
    默认会克隆一份与策略模型同样初始化的 ref_model,也可在某些配置下跳过 / 用 PEFT 方式禁用;
    用于后续计算 KL 惩罚,约束更新不要跑太远。
  • 加载奖励函数(reward_funcs
    • 可以是预训练好的 “SequenceClassification” 模型;
    • 可以是自定义 Python 函数(对回答打分);
    • 或者一个列表,意味着多种奖励加起来一起用。
  • 关键超参数
    • num_generations(同一个 prompt 上一次性采样多少回答,G)
    • max_prompt_lengthmax_completion_length 用于截断长度
    • beta:KL 正则项系数
    • use_vllm:是否用 vLLM 做推理
  • 准备数据集
    继承自 Trainer 的方式,传入 train_dataseteval_dataset 即可。

3.2 _prepare_inputs:采样 + 打分 + 计算相对优势

具体解析请参考笔者的另一篇博客:GRPO 与 TRL实现的GRPOTrainer中_prepare_inputs函数详解

这是 GRPO 的关键。简要流程:

  1. 对 batch 中每个 prompt(如 batch_size=8):
    • 调用模型一次性生成 num_generations 条回答(比如 G=4),所以最终会产生 8 * 4 = 32 条回答。
  2. 对 EOS:如果生成提早出现 EOS token,会用一个 mask(completion_mask)来屏蔽无效 token。
  3. 参考模型 log prob:用 ref_model 对完整序列(prompt + completion)计算 token 级对数概率,用于后面做 KL。
  4. 调用奖励模型:对每条回答打出一个 reward(可以是多个函数相加),形成 [B*G] 形状的 reward。
  5. 相对优势:把这 [B*G] reshape 成 [B, G],对同一个 prompt 的 G 条回答做“均值、标准差”,再 broadcast 回去,以得到每条回答的相对 advantage = ( ( r i − μ ) / σ (r_i - \mu)/\sigma (riμ)/σ)。
  6. 返回:把处理好的 prompt_ids, completion_ids, completion_mask, ref_per_token_logps, advantages 等信息一并打包,等待下一步在 compute_loss 中使用。

3.3 compute_loss:根据 GRPO 公式求 Loss

具体请参考笔者的另一篇博客:从公式到代码:DeepSeek大模型GRPO算法中的 compute_loss如何实现(基于TRL源代码)

compute_loss 中会做以下事情:

  1. 当前策略的 token-level log prob:对 [prompt + completion] 做前向,拿到 logits。
  2. KL 惩罚:参考策略与当前策略的差异,用一个近似公式(per_token_kl)。
  3. 相对优势:之前在 _prepare_inputs 已经算好 advantages
  4. PPO / GRPO style 的公式
    l o s s = − ( exp ⁡ ( log ⁡ p θ − log ⁡ p θ . d e t a c h ( ) ) × A ^ − β × K L ) . \mathrm{loss} = -\left( \exp(\log p_\theta - \log p_\theta.detach()) \times \hat{A}- \beta \times \mathrm{KL} \right). loss=(exp(logpθlogpθ.detach())×A^β×KL).
    并且用 completion_mask 遮住无效 token,对 batch 做平均。
  5. 返回一个标量 loss,供 PyTorch 做 .backward() 与更新。

3.4 训练循环

  • Trainer 带的主循环 train() 会不停地取一个 batch -> _prepare_inputs -> compute_loss -> 反向传播 -> 更新参数 -> 下一个 batch …
  • KL 系数 beta 会让模型别离参考策略太远,而“组内相对奖励”让它学会偏向分数更高的回答。
  • 如此周而复始,就完成了 GRPO 训练。

4. 主要函数/方法亮点

4.1 _set_signature_columns_if_needed

因为我们会在 _prepare_inputs 里自定义处理数据,并不直接依赖典型的 “model inputs” 结构,所以这里覆盖了父类 Trainer 的默认逻辑,告诉它只关心 "prompt" 这个字段即可。

4.2 _get_per_token_logps

具体参考笔者的另一篇博客:(trl的grpo_trainer)深度解析 _get_per_token_logps:如何计算序列中每个 Token 的对数概率
这是一个辅助函数,用来计算模型对 [prompt + completion] 每个 token 的对数概率。

  • 先调用 model(...).logits,再用 gather 提取实际 token 的 logit,再减去 logsumexp 得到 log(prob)。
  • 同时做了“移除最后一列 logit”、“只保留后面若干 token logits”等兼容处理。

4.3 prediction_step

在评估或预测时,也需要执行 _prepare_inputs 来生成多条回答并算 loss,只不过不会再反向传播。

  • 这个函数就是在 eval 阶段或预测时,对应地拿到 loss,用于打印日志或 early stopping 等操作。

4.4 log & create_model_card

  • log:将一些 metrics,比如 KL、reward、completion length 等写到日志。
  • create_model_card:辅助你在做完训练后生成一份 README,包含 base model 名称、训练配置信息、引用等。

5. 总结

grpo_train.py 中的 GRPOTrainer 是一个覆盖了 Trainer 部分方法、带有 GRPO 逻辑的训练器,专为“大模型 + 强化学习”场景设计。它的主要流程可以简述为:

  1. 数据处理 / 采样:对 prompt 生成多条回答,打分并算相对优势;
  2. loss 计算:引入参考模型做 KL 惩罚,把相对优势带入类似 PPO 的公式;
  3. 训练循环:父类 Trainer 提供的 train() 方法会自动按 batch 调用 _prepare_inputs & compute_loss,完成 RL 更新。

从而在没有显式价值函数的情况下,通过“分组比较 + KL 惩罚”高效地优化语言模型回答质量。

如果你想使用 GRPO 来做大语言模型强化学习,只需要准备好数据集(必须包含 prompt 字段)、一个或多个奖励函数(可以是预训练好的分类模型,也可以是自定义 Python 函数),再配合 GRPOTrainer 的超参数,即可快速开始训练。

后记

2025年2月22日15点57分于上海。在GPT4o大模型辅助下完成。

一、目的1. 加速训练过程2. 适应大规模数据3. 资源利用率高4. 提升训练速度5. 增大系统容量6. 提高系统可用性7. 加速模型迭代二、 LLaMA-Factory1.安装2. LLaMA-Factory 校验三、 训练引擎1.DDP2. DeepSpeed3.FSDP四、WebUI五. 参数配置1. 模型2. 数据3. 训练参数4. 多卡参数1. ZeRO-12. ZeRO-23. ZeRO-3六、训练七、推理八、XTuner一、目的分布式训练是一种在多个计算节点上共同完成机器学习模型训练任务的过程,它可以充分利用多台计算机的资源,提高训练效率和模型准确性。分布式训练的主要优势包括:1. 加速训练过程通过并行计算,分布式训练大幅缩短了训练时间,提高了训练效率。提高模型准确性:利用更多的计算资源和数据样本进行训练,减少了过拟合风险,提高了模型的泛化能力和准确性。2. 适应大规模数据分布式训练能够处理传统单机训练难以应对的大规模数据集。3. 资源利用率高有效利用了计算资源,避免了单机训练时的资源闲置和浪费。4. 提升训练速度通过并行计算,分布式训练能够显著缩短模型训练的时间,尤其是在处理大规模数据集和复杂模型时效果更为明显。5. 增大系统容量随着业务量的增长,单机性能已无法满足需求。分布式训练通过多台计算设备的协同工作,能够应对更大规模的应用场景。6. 提高系统可用性分布式架构能够消除单点故障,提高系统的整体可用性。即使某个计算设备出现故障,也不会影响整个训练任务的进行。7. 加速模型迭代在快速迭代的机器学习项目中,分布式训练能够更快地完成模型训练,从而加速模型迭代和优化过程。总的来说,分布式训练在深度学习领域提高训练效率和加快模型收敛的重要手段 。二、 LLaMA-Factory1.安装在安装 LLaMA-Factory 之前,请确保您安装了下列依赖:运行以下指令以安装 LLaMA-Factory 及其依赖:git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.gitcd LLaMA-Factorypip install -e ".[torch,metrics]"123如果出现环境冲突,请尝试使用 pip install --no-deps -e . 解决2. LLaMA-Factory 校验完成安装后,可以通过使用 llamafactory-cli version 来快速校验安装是否成功如果看到类似下面的界面,就说明安装成功了。 Successfully uninstalled requests-2.31.0 Attempting uninstall: anyio Found existing installation: anyio 4.4.0 Uninstalling anyio-4.4.0: Successfully uninstalled anyio-4.4.0Successfully installed accelerate-1.2.1 aiofiles-23.2.1 aiohappyeyeballs-2.4.6 aiohttp-3.11.12 aiosignal-1.3.2 annotated-types-0.7.0 anyio-4.8.0 audioread-3.0.1 av-14.1.0 click-8.1.8 datasets-3.2.0 dill-0.3.8 docstring-parser-0.16 einops-0.8.1 fastapi-0.115.8 ffmpy-0.5.0 fire-0.7.0 frozenlist-1.5.0 gradio-5.12.0 gradio-client-1.5.4 huggingface-hub-0.28.1 jieba-0.42.1 joblib-1.4.2 lazy-loader-0.4 librosa-0.10.2.post1 llamafactory-0.9.2.dev0 llvmlite-0.44.0 markdown-it-py-3.0.0 mdurl-0.1.2 msgpack-1.1.0 multidict-6.1.0 multiprocess-0.70.16 nltk-3.9.1 numba-0.61.0 orjson-3.10.15 pandas-2.2.3 peft-0.12.0 pooch-1.8.2 propcache-0.2.1 pyarrow-19.0.0 pydantic-2.10.6 pydantic-core-2.27.2 pydub-0.25.1 python-multipart-0.0.20 pytz-2025.1 regex-2024.11.6 requests-2.32.3 rich-13.9.4 rouge-chinese-1.0.3 ruff-0.9.6 safehttpx-0.1.6 safetensors-0.5.2 scikit-learn-1.6.1 scipy-1.15.1 semantic-version-2.10.0 sentencepiece-0.2.0 shellingham-1.5.4 shtab-1.7.1 soundfile-0.13.1 soxr-0.5.0.post1 sse-starlette-2.2.1 starlette-0.45.3 termcolor-2.5.0 threadpoolctl-3.5.0 tiktoken-0.9.0 tokenizers-0.21.0 tomlkit-0.13.2 tqdm-4.67.1 transformers-4.48.3 trl-0.9.6 typer-0.15.1 typing-extensions-4.12.2 tyro-0.8.14 tzdata-2025.1 uvicorn-0.34.0 websockets-14.2 xxhash-3.5.0 yarl-1.18.3WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venvroot@autodl-container-c2d74383d9-db8bb7c4:~/autodl-tmp/LLaMA-Factory# llamafactory-cli version----------------------------------------------------------| Welcome to LLaMA Factory, version 0.9.2.dev0 || || Project page: https://github.com/hiyouga/LLaMA-Factory |----------------------------------------------------------root@autodl-container-c2d74383d9-db8bb7c4:~/autodl-tmp/LLaMA-Factory# 1234567891011121314三、 训练引擎LLaMA-Factory 支持单机多卡和多机多卡分布式训练。同时也支持 DDP , DeepSpeed 和 FSDP 三种分布式引擎。1.DDPDDP (DistributedDataParallel) 通过实现模型并行和数据并行实现训练加速。 使用 DDP 的程序需要生成多个进程并且为每个进程创建一个 DDP 实例,他们之间通过 torch.distributed 库同步。2. DeepSpeedDeepSpeed 是微软开发的分布式训练引擎,并提供ZeRO(Zero Redundancy Optimizer)、offload、Sparse Attention、1 bit Adam、流水线并行等优化技术。 您可以根据任务需求与设备选择使用。3.FSDP通过全切片数据并行技术(Fully Sharded Data Parallel)来处理更多更大的模型。在 DDP 中,每张 GPU 都各自保留了一份完整的模型参数和优化器参数。而 FSDP 切分了模型参数、梯度与优化器参数,使得每张 GPU 只保留这些参数的一部分。 除了并行技术之外,FSDP 还支持将模型参数卸载至CPU,从而进一步降低显存需求。由于deepseek分布式训练加速,采用混合精度(fp16/fp32)和ZeRO优化,减少显存占用,从而加速训练。所以本文采用DeepSpeed 是训练引擎。四、WebUILLaMA-Factory 支持通过 WebUI 零代码微调大语言模型。 在完成 安装 后,您可以通过以下指令进入 WebUI:llamafactory-cli webui1WebUI 主要分为四个界面:训练、评估与预测、对话、导出。当运行上面命令后,打开如下界面在开始训练模型之前,需要指定的参数有:模型名称及路径训练阶段微调方法训练数据集学习率、训练轮数等训练参数微调参数等其他参数输出目录及配置路径
最新发布
03-20
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值