Open-R1 项目代码文件的详细剖析

目录

1. configs.py

功能概述

关键代码与细节

2. evaluate.py

功能概述

关键代码与细节

3. generate.py

功能概述

关键代码与细节

4. grpo.py

功能概述

关键代码与细节

5. rewards.py

功能概述

关键代码与细节

6. sft.py

功能概述

关键代码与细节

安装

训练模型

评估模型

复现DeepSeek的评估结果

MATH-500

GPQA Diamond

数据生成流程


技术实现与细节

以下是对提供的代码文件的详细剖析,结合代码内容和项目背景,分析其功能、实现细节和应用场景。

1. configs.py

功能概述

configs.py 文件定义了两种配置类:GRPOConfigSFTConfig,分别用于 GRPO(Group Relative Policy Optimization)训练和 SFT(Supervised Fine-Tuning)训练。这些配置类继承自 trl(Transformers Reinforcement Learning)库中的基础配置类,并添加了一些额外的参数。

关键代码与细节
  • GRPOConfig 和 SFTConfig

    @dataclass
    class GRPOConfig(trl.GRPOConfig):
        benchmarks: list[str] = field(
            default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
        )
        callbacks: list[str] = field(
            default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
        )
        system_prompt: Optional[str] = field(
            default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
        )
        hub_model_revision: Optional[str] = field(
            default="main", metadata={"help": "The Hub model branch to push the model to."}
        )
        overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
        push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    • 继承关系GRPOConfigSFTConfig 继承自 trl.GRPOConfigtrl.SFTConfig,扩展了这些类的功能。

    • 新增参数

      • benchmarks:训练后运行的基准测试列表。

      • callbacks:训练过程中运行的回调函数列表。

      • system_prompt:用于基准测试的系统提示。

      • hub_model_revision:推送模型到 Hugging Face Hub 的分支。

      • overwrite_hub_revisionpush_to_hub_revision:控制是否覆盖或推送模型版本。

  • 应用场景

    • 这些配置类用于定义训练和评估的参数,支持用户自定义训练流程中的各种设置,如基准测试、回调函数和模型版本管理。

2. evaluate.py

功能概述

evaluate.py 文件定义了自定义的评估任务,用于在 LightEval 框架中评估模型的性能。这些任务包括数学推理、问答等。

关键代码与细节
  • 评估指标

    latex_gold_metric = multilingual_extractive_match_metric(
        language=Language.ENGLISH,
        fallback_mode="first_match",
        precision=5,
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
        aggregation_function=max,
    )
    • multilingual_extractive_match_metric:一个多语言的提取匹配指标,用于评估模型生成的内容是否与参考答案匹配。

    • gold_extraction_targetpred_extraction_target:定义了从参考答案和模型生成内容中提取信息的配置。

  • 提示函数

    def prompt_fn(line, task_name: str = None):
        return Doc(
            task_name=task_name,
            query=line["problem"],
            choices=[line["solution"]],
            gold_index=0,
        )
    • prompt_fn:生成评估任务的提示,用于数学推理任务。

    • aime_prompt_fngpqa_prompt_fn:分别为 AIME 和 GPQA 任务生成提示。

  • 任务定义

    aime24 = LightevalTaskConfig(
        name="aime24",
        suite=["custom"],
        prompt_function=aime_prompt_fn,
        hf_repo="HuggingFaceH4/aime_2024",
        hf_subset="default",
        hf_avail_splits=["train"],
        evaluation_splits=["train"],
        few_shots_split=None,
        few_shots_select=None,
        generation_size=32768,
        metric=[expr_gold_metric],
        version=1,
    )
    • LightevalTaskConfig:定义了一个评估任务的配置,包括任务名称、提示函数、数据集、评估指标等。

    • TASKS_TABLE:将所有定义的任务存储在一个列表中,便于管理和运行。

  • 应用场景

    • 该文件用于定义和运行模型的评估任务,支持多种数学推理和问答任务,帮助用户评估模型在不同领域的性能。

3. generate.py

功能概述

generate.py 文件定义了一个用于生成数据的管道,使用 distilabel 工具从模型中生成合成数据。

关键代码与细节
  • 构建管道

    def build_distilabel_pipeline(
        model: str,
        base_url: str = "http://localhost:8000/v1",
        prompt_column: Optional[str] = None,
        prompt_template: str = "{{ instruction }}",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_new_tokens: int = 8192,
        num_generations: int = 1,
        input_batch_size: int = 64,
        client_replicas: int = 1,
        timeout: int = 900,
        retries: int = 0,
    ) -> Pipeline:
        ...
    • build_distilabel_pipeline:构建一个 distilabel 管道,用于生成数据。

    • 参数

      • model:用于生成数据的模型名称。

      • base_url:模型服务器的 URL。

      • prompt_columnprompt_template:定义提示的列和模板。

      • temperaturetop_p:生成的温度和核采样参数。

      • max_new_tokens:生成的最大新 token 数量。

      • num_generations:每个输入生成的样本数量。

      • input_batch_size:输入的批量大小。

      • client_replicas:客户端副本数量,用于并行处理。

      • timeoutretries:请求超时和重试次数。

  • 主函数

    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
        ...
        args = parser.parse_args()
        ...
        pipeline = build_distilabel_pipeline(
            model=args.model,
            base_url=args.vllm_server_url,
            prompt_template=args.prompt_template,
            prompt_column=args.prompt_column,
            temperature=args.temperature,
            top_p=args.top_p,
            max_new_tokens=args.max_new_tokens,
            num_generations=args.num_generations,
            input_batch_size=args.input_batch_size,
            client_replicas=args.client_replicas,
            timeout=args.timeout,
            retries=args.retries,
        )
        ...
        distiset = pipeline.run(
            dataset=dataset,
            dataset_batch_size=args.input_batch_size * 1000,
            use_cache=False,
        )
        ...
    • 命令行参数:通过 argparse 解析命令行参数,支持用户自定义生成数据的配置。

    • 数据加载:使用 datasets 加载数据集。

    • 管道运行:运行生成管道,生成合成数据并保存到 Hugging Face Hub。

  • 应用场景

    • 该文件用于生成合成数据,支持用户自定义生成配置,适用于模型训练和数据增强。

4. grpo.py

功能概述

grpo.py 文件实现了 GRPO(Group Relative Policy Optimization)训练流程,用于优化模型的策略。

关键代码与细节
  • GRPOScriptArguments

    @dataclass
    class GRPOScriptArguments(ScriptArguments):
        reward_funcs: list[str] = field(
            default_factory=lambda: ["accuracy", "format"],
            metadata={
                "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
            },
        )
        cosine_min_value_wrong: float = field(
            default=0.0,
            metadata={"help": "Minimum reward for wrong answers"},
        )
        ...
    • reward_funcs:定义奖励函数列表,支持多种奖励函数,如准确率、格式、推理步骤、余弦缩放和重复惩罚。

    • cosine_min_value_wrong 等参数:定义余弦缩放奖励的参数。

  • 主函数

    def main(script_args, training_args, model_args):
        ...
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
        ...
        reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
        ...
        trainer = GRPOTrainer(
            model=model_args.model_name_or_path,
            reward_funcs=reward_funcs,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
            peft_config=get_peft_config(model_args),
            callbacks=get_callbacks(training_args, model_args),
        )
        ...
    • 数据加载:使用 datasets 加载训练和评估数据集。

    • 奖励函数:根据用户指定的奖励函数,加载相应的函数。

    • GRPOTrainer:初始化 GRPO 训练器,设置模型、奖励函数、训练参数等。

    • 训练循环:运行训练循环,支持从断点恢复训练。

  • 应用场景

    • 该文件用于 GRPO 训练,支持多种奖励函数和训练配置,适用于优化模型的策略。

5. rewards.py

功能概述

rewards.py 文件定义了多种奖励函数,用于在 GRPO 训练中评估模型生成的内容。

关键代码与细节
  • 奖励函数

    def accuracy_reward(completions, solution, **kwargs):
        ...
        reward = float(verify(answer_parsed, gold_parsed))
        ...
    • accuracy_reward:检查模型生成的内容是否与参考答案一致,返回 1 或 0。

    • format_reward:检查生成内容是否符合特定格式。

    • reasoning_steps_reward:检查生成内容是否包含清晰的推理步骤。

    • cosine_scaled_reward:基于生成内容长度的余弦缩放奖励。

    • repetition_penalty_reward:基于重复 n-gram 的惩罚奖励。

  • 应用场景

    • 这些奖励函数用于 GRPO 训练,帮助模型生成更准确、更符合格式、更具推理性和更少重复的内容。

6. sft.py

功能概述

sft.py 文件实现了 SFT(Supervised Fine-Tuning)训练流程,用于对模型进行有监督微调。

关键代码与细节
  • 主函数

    def main(script_args, training_args, model_args):
        ...
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
        ...
        trainer = SFTTrainer(
            model=model_args.model_name_or_path,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
            processing_class=tokenizer,
            peft_config=get_peft_config(model_args),
            callbacks=get_callbacks(training_args, model_args),
        )
        ...
    • 数据加载:使用 datasets 加载训练和评估数据集。

    • SFTTrainer:初始化 SFT 训练器,设置模型、训练参数、分词器等。

    • 训练循环:运行训练循环,支持从断点恢复训练。

  • 应用场景

    • 该文件用于 SFT 训练,支持多种训练配置和回调函数,适用于对模型进行有监督微调。

功能模块

  • 模型训练

    • SFT(Supervised Fine-Tuning):对预训练模型进行微调,使其更好地适应特定任务。例如,在指令微调中,将小样本数据集用于微调,使模型生成更符合人类常识的对话内容。

    • GRPO(Group-Relative Policy Optimization):使用 GRPO 方法对模型进行 RL(强化学习)培训。该方法基于代理与环境之间的交互,通过最大化累积奖励信号来训练策略模型。

  • 模型评估

    • 使用 lighteval 对模型进行评估,lighteval 是一种轻量级的评估工具,支持多种评估任务。例如,在 AIME 2024、MATH-500 和 GPQA Diamond 等任务上对模型进行测试,得到准确率等评估指标,以评估模型的性能。

  • 数据生成

    • 从 smol 蒸馏 R1 模型生成数据:使用轻量级的蒸馏 R1 模型生成数据。该模块通过 Distilabel 来生成合成数据,为模型训练提供更多样化的数据。

    • 从 DeepSeek-R1 生成数据:使用更大的 DeepSeek-R1 模型生成数据。这需要更多的计算资源,但可以生成更高质量的合成数据,以支持更复杂的模型训练和测试。

安装

[!CAUTION]
相关库依赖于CUDA 12.4。如果您看到与段错误相关的错误,请使用nvcc --version仔细检查您的系统正在运行的CUDA版本。

要运行这个项目中的代码,首先,使用例如uv创建一个Python虚拟环境。
要安装uv,请参考UV安装指南

uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip --link-mode=copy

接下来,安装vLLM:

uv pip install vllm==0.7.1 --link-mode=copy

这也会安装PyTorch v2.5.1,使用这个版本非常重要,因为vLLM的二进制文件是针对该版本编译的。然后,您可以通过pip install -e .[LIST OF MODES]安装特定用例的其余依赖项。对于大多数贡献者,我们建议:

GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" --link-mode=copy

接下来,按如下方式登录您的Hugging Face和Weights and Biases账户:

 

训练模型

我们支持使用数据并行分布式训练(DDP)或DeepSpeed(ZeRO-2和ZeRO-3)来训练模型。例如,要在从DeepSeek-R1提炼的带有推理痕迹的数据集(如Bespoke-Stratos-17k)上运行监督微调(SFT),请运行以下命令:

# 通过命令行进行训练
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
    --learning_rate 2.0e-5 \
    --num_train_epochs 1 \
    --packing \
    --max_seq_length 4096 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --bf16 \
    --output_dir data/Qwen2.5-1.5B-Open-R1-Distill

# 通过YAML配置文件进行训练
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml

目前,支持以下任务:

评估模型

make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=data NUM_GPUS=8

要使用张量并行:

make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=tensor NUM_GPUS=8

复现DeepSeek的评估结果

MATH-500

我们能够在约1 - 3个标准差范围内复现DeepSeek在MATH-500基准测试上报告的结果:

模型MATH-500(🤗 LightEval)MATH-500(DeepSeek报告值)
DeepSeek-R1-Distill-Qwen-1.5B81.283.9
DeepSeek-R1-Distill-Qwen-7B91.892.8
DeepSeek-R1-Distill-Qwen-14B94.293.9
DeepSeek-R1-Distill-Qwen-32B95.094.3
DeepSeek-R1-Distill-Llama-8B85.489.1
DeepSeek-R1-Distill-Llama-70B93.494.5

要复现这些结果,请使用以下命令:

NUM_GPUS=1 # 对于32B和70B模型,设置为8
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS"
OUTPUT_DIR=data/evals/$MODEL

lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir $OUTPUT_DIR

GPQA Diamond

lighteval vllm $MODEL_ARGS "custom|gpqa:diamond|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir $OUTPUT_DIR
python scripts/run_benchmarks.py --model-id={model_id}  --benchmarks gpqa
数据生成流程
  • 小模型蒸馏数据生成

    • 使用轻量级蒸馏 R1 模型生成数据。通过 Distilabel 工具,从预定义的提示模板和数据集出发,生成合成数据。

    • 例如,使用 DeepSeek-R1 的蒸馏 Qwen-7B 模型生成数学推理数据,将数据保存到远程数据集中,并可通过华为 MindSpore 加载该数据集以用于训练。

  • 大模型数据生成

    • 使用更大的 DeepSeek-R1 模型生成数据,需要更多的计算资源。通过 Slurm 脚本(如 slurm/generate.slurm)在集群上运行生成任务,可以高效地生成大规模合成数据。

    • 生成过程中,可以通过设置温度(如 0.6)、提示列(如 “problem”)等参数来控制生成数据的质量和多样性。

更多可参照GitHub - huggingface/open-r1: Fully open reproduction of DeepSeek-R1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI仙人掌

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值