DeepSeek-R1-Zero 的训练过程:pytorch代码实现

DeepSeek-R1-Zero 的训练过程是基于纯强化学习(Reinforcement Learning, RL)的方法,不依赖监督微调(Supervised Fine-Tuning, SFT)作为预备步骤。以下是 DeepSeek-R1-Zero 训练过程的详细说明,基于第 2.2 节(DeepSeek-R1-Zero: Reinforcement Learning on the Base Model)的内容:


DeepSeek-R1-Zero 的训练过程

1. 起点:基于基础模型
  • 基础模型:DeepSeek-R1-Zero 以 DeepSeek-V3-Base 作为基础模型。
  • 无监督数据:训练从一开始就避免使用任何监督数据,目标是通过纯 RL 过程让模型自我进化,探索其推理能力的潜力。
2. 强化学习算法:Group Relative Policy Optimization (GRPO)
  • 算法选择:使用 GRPO(Shao et al., 2024)作为 RL 框架。GRPO 是一种高效的策略优化方法,相较于传统 RL(如 PPO),它不依赖与策略模型同等规模的 critic 模型,而是通过组评分(group scores)估算基线,从而降低训练成本。
  • 优化目标:GRPO 的目标是通过采样一组输出并根据奖励优化策略模型。具体公式如下:
    • 对于每个问题 ( q q q ),从旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ) 中采样一组输出 ( { o 1 , o 2 , ⋯   , o G } \{o_1, o_2, \cdots, o_G\} {o1,o2,,oG} );
    • 优化新策略 ( π θ \pi_\theta πθ ) 的目标是最大化以下函数:
      J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ old ( O ∣ q ) ] [ 1 G ∑ i = 1 G ( min ⁡ ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ε , 1 + ε ) A i ) − β D K L ( π θ ∥ π ref ) ) ] \mathcal{J}_{GRPO}(\theta) = \mathbf{E} \left[ q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(O \mid q) \right] \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta \mathbb{D}_{KL}(\pi_\theta \| \pi_{\text{ref}}) \right) \right] JGRPO(θ)=E[qP(Q),{oi}i=1Gπθold(Oq)][G1i=1G(min(πθold(oiq)πθ(oiq)Ai,clip(πθold(oiq)πθ(oiq),1ε,1+ε)Ai)βDKL(πθπref))]
      其中:
      • ( A i A_i Ai ) 是优势函数,根据一组奖励 ( { r 1 , r 2 , … , r G } \{r_1, r_2, \ldots, r_G\} {r1,r2,,rG} ) 计算:
        A i = r i − mean ( { r 1 , r 2 , ⋯   , r G } ) std ( { r 1 , r 2 , ⋯   , r G } ) A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \cdots, r_G\})}{\text{std}(\{r_1, r_2, \cdots, r_G\})} Ai=std({r1,r2,,rG})rimean({r1,r2,,rG})
      • ( D K L \mathbb{D}_{KL} DKL ) 是 KL 散度,用于正则化以防止策略偏离参考策略 ( π ref \pi_{\text{ref}} πref ) 过远;
      • ( ε \varepsilon ε ) 和 ( β \beta β ) 是超参数,用于控制剪切范围和 KL 散度的权重。
3. 奖励模型(Reward Modeling)
  • 奖励来源:奖励是训练信号的核心,决定模型优化的方向。DeepSeek-R1-Zero 使用基于规则的奖励系统,包括两种类型:
    1. 准确性奖励(Accuracy Rewards)
      • 用于评估回答是否正确。例如,对于数学问题,要求模型以指定格式(如框住最终答案)输出结果,便于规则验证;对于 LeetCode 问题,则通过编译器基于预定义测试用例生成反馈。
    2. 格式奖励(Format Rewards)
      • 强制模型将推理过程放在 <think></think> 标签之间,以确保输出结构符合预期。
  • 避免神经奖励模型:文档提到,未使用基于神经网络的奖励模型(如过程或结果神经奖励模型),因为在大规模 RL 中这些模型容易导致“奖励黑客”(reward hacking),且重新训练奖励模型会增加资源消耗和复杂性。
4. 训练模板(Training Template)
  • 模板设计:训练使用一个简单模板(见Table 1),要求模型首先生成推理过程(标记为 <think></think>),然后提供最终答案(标记为 <answer></answer>)。
    在这里插入图片描述

  • 无内容偏见:模板仅限制结构,不强制要求特定推理策略(如反思或特定解题方法),以观察模型在 RL 过程中的自然演化。

5. 训练过程与结果
  • 性能提升:在 RL 训练中,DeepSeek-R1-Zero 的性能逐步提高。例如,在 AIME 2024 基准上的 pass@1 分数从初始的 15.6% 提升到 71.0%,通过多数投票(majority voting)进一步达到 86.7%,媲美 OpenAI-o1-0912。

  • 自我进化:模型通过 RL 自然发展出复杂行为,如:

    • 延长推理时间:生成更长的推理链(从数百到数千个 token),以解决更复杂的任务。
    • 反思与替代策略:自发出现反思(重新评估先前步骤)和探索替代解法的能力。
  • “Aha Moment”:训练中观察到“顿悟时刻”,如 Table 3 所示,模型学会重新评估初始方法并调整策略,显示出类人化的推理语气。
    在这里插入图片描述

  • 训练步骤:经过数千次 RL 迭代,模型在推理基准上表现出超强性能。

6. 挑战与局限
  • 问题:尽管推理能力强大,DeepSeek-R1-Zero 存在可读性差(poor readability)和语言混合(language mixing)等问题。这些问题推动了后续 DeepSeek-R1 的开发,引入冷启动数据和多阶段训练来解决。

总结

DeepSeek-R1-Zero 的训练过程是一个从基础模型 DeepSeek-V3-Base 开始,基于 GRPO 算法的大型纯强化学习实验。通过规则驱动的准确性和格式奖励,结合简单的训练模板,模型在不依赖监督数据的情况下,逐步自我进化,发展出强大的推理能力(如长链推理、反思等)。其性能在 AIME 等基准上显著提升,但也暴露出可读性和语言一致性方面的不足,为后续改进提供了方向。

DeepSeek-R1-Zero训练代码实现

为 DeepSeek-R1-Zero 的训练过程,尤其是奖励模型(Reward Modeling)部分,提供代码实现。这部分是基于规则的奖励系统,包括准确性奖励(Accuracy Rewards)和格式奖励(Format Rewards),并且避免使用神经奖励模型。这里提供一个简化的 PyTorch 实现,重点体现奖励模型的设计和集成。


设计思路

  1. 奖励模型的目标
    • 准确性奖励:检查模型输出是否正确,例如数学问题的最终答案是否与标准答案匹配,或 LeetCode 问题的代码是否通过测试用例。
    • 格式奖励:确保输出符合 <think>推理过程</think><answer>答案</answer> 的结构。
  2. 实现方式
    • 使用规则逻辑(而不是神经网络)计算奖励。
    • 将奖励融入 GRPO 的训练循环,用于计算优势函数 ( A i A_i Ai ) 和优化策略。
  3. 代码结构
    • 定义一个奖励计算函数,处理准确性和格式两部分。
    • 将其嵌入 GRPO 的损失计算中。

以下是代码实现:


代码实现

import torch
import torch.nn as nn
import re
from typing import List, Tuple

# 假设的策略网络(Actor),输出 token 序列的分布
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vocab_size)  # 输出每个 token 的 logits
        )
    
    def forward(self, states):
        return self.net(states)  # 返回 logits

# 奖励模型:基于规则计算准确性和格式奖励
def compute_reward(output: str, ground_truth: str = None, problem_type: str = "math") -> float:
    """
    计算单个输出的奖励。
    - output: 模型生成的文本
    - ground_truth: 标准答案(数学问题提供,LeetCode 可为空)
    - problem_type: "math" 或 "leetcode"
    """
    # 初始化奖励
    total_reward = 0.0
    
    # 1. 格式奖励:检查是否符合 <think></think><answer></answer> 结构
    think_pattern = r"<think>.*?</think>"
    answer_pattern = r"<answer>.*?</answer>"
    has_think = bool(re.search(think_pattern, output))
    has_answer = bool(re.search(answer_pattern, output))
    
    format_reward = 1.0 if (has_think and has_answer) else 0.0
    total_reward += format_reward  # 格式奖励权重为 1.0

    # 2. 准确性奖励
    if problem_type == "math" and ground_truth:
        # 提取 <answer> 中的答案
        answer_match = re.search(r"<answer>(.*?)</answer>", output)
        if answer_match:
            predicted_answer = answer_match.group(1).strip()
            # 检查是否与标准答案匹配(假设答案是数字或简单字符串)
            accuracy_reward = 1.0 if predicted_answer == ground_truth else 0.0
            total_reward += accuracy_reward  # 准确性奖励权重为 1.0
        else:
            total_reward += 0.0  # 无答案则准确性奖励为 0
    
    elif problem_type == "leetcode":
        # 模拟 LeetCode 测试用例验证(这里简化为编译器检查逻辑)
        # 假设 output 是代码,调用外部函数验证(伪代码)
        from compiler import run_tests  # 假设的外部编译器接口
        test_result = run_tests(output)  # 返回 True/False 表示是否通过测试
        accuracy_reward = 1.0 if test_result else 0.0
        total_reward += accuracy_reward

    return total_reward

# GRPO 训练核心逻辑
def compute_grpo_loss(
    policy: PolicyNetwork,
    old_policy: PolicyNetwork,
    ref_policy: PolicyNetwork,
    states: torch.Tensor,
    outputs: List[str],
    rewards: List[float],
    epsilon: float = 0.2,
    beta: float = 0.01
) -> torch.Tensor:
    """
    计算 GRPO 损失。
    - policy: 当前策略网络
    - old_policy: 旧策略网络(用于采样和比率计算)
    - ref_policy: 参考策略(用于 KL 散度)
    - states: 输入状态(问题描述的 embedding)
    - outputs: 模型生成的输出序列
    - rewards: 每个输出的奖励
    """
    G = len(outputs)  # 组大小
    assert G > 0, "Group size must be positive"

    # 计算新旧策略的概率比和 KL 散度
    policy_logits = policy(states)  # [batch_size, seq_len, vocab_size]
    old_logits = old_policy(states).detach()  # 旧策略不更新梯度
    ref_logits = ref_policy(states).detach()  # 参考策略固定

    # 假设 outputs 已转换为 token IDs,这里简化为对数概率的近似
    # 在实际中需要序列化处理,这里用占位符
    log_probs = torch.softmax(policy_logits, dim=-1).log()  # 模拟 log π_θ(o|q)
    old_log_probs = torch.softmax(old_logits, dim=-1).log()  # log π_θ_old(o|q)
    ref_log_probs = torch.softmax(ref_logits, dim=-1).log()  # log π_ref(o|q)

    # 计算概率比 r_t(θ)
    ratios = torch.exp(log_probs - old_log_probs)  # π_θ(o|q) / π_θ_old(o|q)

    # 计算优势 A_i
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    mean_reward = rewards_tensor.mean()
    std_reward = rewards_tensor.std() + 1e-6  # 避免除零
    advantages = (rewards_tensor - mean_reward) / std_reward

    # 剪切损失项
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
    clipped_loss = torch.min(surr1, surr2).mean()

    # KL 散度正则化项
    kl_div = (ref_log_probs - log_probs).mean()  # 简化的 KL 散度计算

    # 总损失
    loss = clipped_loss - beta * kl_div
    return -loss  # 最大化 J_GRPO 等价于最小化 -loss

# 训练循环
def train_deepseek_r1_zero(
    policy: PolicyNetwork,
    optimizer: torch.optim.Optimizer,
    problems: List[Tuple[str, str]],  # (问题描述, 标准答案)
    num_epochs: int = 100,
    group_size: int = 16
):
    ref_policy = PolicyNetwork(...)  # 初始化参考策略(固定)
    old_policy = PolicyNetwork(...)  # 初始化旧策略
    old_policy.load_state_dict(policy.state_dict())  # 同步初始参数

    for epoch in range(num_epochs):
        for problem, ground_truth in problems:
            # 准备输入状态(假设问题已转为 embedding)
            state = torch.tensor(...)  # 简化为占位符

            # 从旧策略采样一组输出
            outputs = []
            for _ in range(group_size):
                with torch.no_grad():
                    logits = old_policy(state)
                    # 假设通过采样生成序列(简化处理)
                    output = "sampled_output"  # 实际需解码为文本
                    outputs.append(output)

            # 计算奖励
            rewards = [compute_reward(output, ground_truth, "math") for output in outputs]

            # 计算 GRPO 损失并更新策略
            optimizer.zero_grad()
            loss = compute_grpo_loss(policy, old_policy, ref_policy, state, outputs, rewards)
            loss.backward()
            optimizer.step()

            # 更新旧策略
            old_policy.load_state_dict(policy.state_dict())

        print(f"Epoch {epoch}, Loss: {loss.item()}")

# 示例调用
policy = PolicyNetwork(input_dim=512, hidden_dim=1024, vocab_size=30000)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
problems = [("Solve 2+2", "4"), ("Solve 3*3", "9")]
train_deepseek_r1_zero(policy, optimizer, problems)

代码说明

1. 奖励模型(compute_reward
  • 准确性奖励
    • 对于数学问题,使用正则表达式提取 <answer> 中的答案,与 ground_truth 比较。如果匹配,奖励为 1.0,否则为 0.0。
    • 对于 LeetCode 问题,假设有一个外部函数 run_tests(这里是伪代码),检查代码是否通过测试用例。
  • 格式奖励
    • 使用正则表达式检查 <think><answer> 标签是否存在,符合要求则奖励 1.0,否则 0.0。
  • 总奖励
    • 简单相加准确性和格式奖励(权重均为 1.0,可根据需要调整)。
  • 避免神经奖励模型
    • 实现完全基于规则逻辑,不依赖任何神经网络,确保简单性和可控性,避免 reward hacking。
2. GRPO 损失(compute_grpo_loss
  • 输入
    • outputs 是模型生成的文本序列,rewards 是对应奖励。
  • 优势计算
    • 根据公式 ( A i = r i − mean std A_i = \frac{r_i - \text{mean}}{\text{std}} Ai=stdrimean ),直接用组内奖励统计计算。
  • 剪切和 KL
    • 保留 PPO 的剪切机制,同时加入 ( D K L ( π θ ∣ ∣ π ref ) D_{KL}(\pi_\theta || \pi_{\text{ref}}) DKL(πθ∣∣πref) )。
  • 简化假设
    • 这里假设 log_probs 从 logits 计算,实际语言模型需处理序列生成(例如用 transformers 的 generate 方法)。
3. 训练循环(train_deepseek_r1_zero
  • 采样:从旧策略生成一组输出(这里简化为占位符,实际需实现 token-by-token 生成)。
  • 奖励计算:调用 compute_reward 为每个输出评分。
  • 优化:计算 GRPO 损失并更新策略。

注意事项

  1. 简化之处
    • 代码中省略了序列生成的细节(例如如何将 logits 转为文本)。实际中需结合 transformers 库,使用 model.generate() 并计算每步的概率。
    • state 的生成(问题转为 embedding)未实现,可用预训练模型(如 BERT)编码问题。
  2. 可扩展性
    • 可以为不同问题类型添加更多规则(例如科学推理、逻辑题)。
    • 调整奖励权重(例如格式奖励占 0.3,准确性占 0.7)。
  3. 性能优化
    • 批量处理多组样本以加速训练。
    • 使用分布式计算支持大规模 RL。

如何体现原论文中的设计?

  • 规则奖励compute_reward 严格遵循文档描述,使用正则表达式和逻辑判断实现准确性和格式检查。
  • 避免神经模型:奖励计算不涉及任何神经网络,完全基于规则,符合文档避免 reward hacking 的目标。
  • GRPO 集成:奖励值直接用于优势计算,驱动策略优化,与 DeepSeek-R1-Zero 的训练过程一致。

PPO和GRPO的区别

PPO 中的 Critic Model(Value Model)的作用

在 PPO(以及许多 Actor-Critic 类型的强化学习算法)中,critic model(通常称为价值模型)是一个独立的神经网络,用于估计状态的价值函数 ( V ( s ) V(s) V(s) )。它的主要作用是为策略优化提供一个基线(baseline),从而计算优势函数 ( A ( s , a ) A(s, a) A(s,a) ),帮助减少策略梯度估计的方差,提升训练的稳定性。

1. Critic Model 如何使用?
  • 价值函数估计:Critic model 的任务是预测给定状态 ( s s s ) 下的期望累计回报,即 ( V π ( s ) = E [ ∑ t γ t r t ∣ s 0 = s ] V^\pi(s) = \mathbb{E} [ \sum_t \gamma^t r_t | s_0 = s ] Vπ(s)=E[tγtrts0=s] ),其中 ( γ \gamma γ ) 是折扣因子,( r t r_t rt ) 是奖励。
  • 与策略模型的关系:在 PPO 中,策略模型(Actor)负责输出动作分布 ( π ( a ∣ s ) \pi(a|s) π(as) ),而 Critic model 提供价值估计,二者共同优化。
  • 训练数据:Critic model 通常通过时序差分(TD)学习或蒙特卡洛回报来更新,使用实际获得的回报 ( R R R )(如 GAE,Generalized Advantage Estimation)作为监督信号。
2. Critic Model 在计算优势(Advantage)中的作用

是的,Critic model 的核心作用之一就是用于计算优势函数 ( A ( s , a ) A(s, a) A(s,a) )。优势函数衡量的是在状态 ( s s s ) 下选择动作 ( a a a ) 相比于平均情况(即状态价值 ( V ( s ) V(s) V(s) ))的额外收益。PPO 中常用的计算方式是:

A ( s , a ) = Q ( s , a ) − V ( s ) A(s, a) = Q(s, a) - V(s) A(s,a)=Q(s,a)V(s)

其中:

  • ( Q ( s , a ) Q(s, a) Q(s,a) ) 是动作价值函数,表示在状态 ( s s s ) 下执行动作 ( a a a ) 后的期望回报;
  • ( V ( s ) V(s) V(s) ) 是状态价值函数,由 Critic model 预测。

在实践中,直接计算 ( Q ( s , a ) Q(s, a) Q(s,a) ) 需要完整的环境动态,而这通常不可行。因此,PPO 通常使用 时序差分误差广义优势估计(GAE) 来近似 ( A ( s , a ) A(s, a) A(s,a) ):

  • 简单 TD 误差
    A ( s , a ) = r + γ V ( s ′ ) − V ( s ) A(s, a) = r + \gamma V(s') - V(s) A(s,a)=r+γV(s)V(s)
    其中 ( r r r ) 是即时奖励,( s ′ s' s ) 是下一状态。
  • GAE(更常用)
    A t = ∑ l = 0 ∞ ( γ λ ) l δ t + l , 其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) A_t = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}, \quad \text{其中} \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) At=l=0(γλ)lδt+l,其中δt=rt+γV(st+1)V(st)
    这里 ( λ \lambda λ ) 是 GAE 的超参数,用于平衡偏差和方差。

Critic model 通过输出 ( V ( s ) V(s) V(s) ) 和 ( V ( s ′ ) V(s') V(s) ),为这些计算提供基线,使得优势函数能够反映动作的相对优劣。

3. PPO 的损失函数中如何体现?

PPO 的损失函数由两部分组成:策略损失(policy loss)价值损失(value loss),分别优化 Actor 和 Critic,加上一个可选的熵正则化项。完整的 PPO 损失函数如下:

L P P O ( θ , ϕ ) = L p o l i c y ( θ ) + c 1 L v a l u e ( ϕ ) − c 2 S [ π θ ] ( s ) L^{PPO}(\theta, \phi) = L^{policy}(\theta) + c_1 L^{value}(\phi) - c_2 S[\pi_\theta](s) LPPO(θ,ϕ)=Lpolicy(θ)+c1Lvalue(ϕ)c2S[πθ](s)

  • ( θ \theta θ ) 和 ( ϕ \phi ϕ ):分别表示策略网络(Actor)和价值网络(Critic)的参数。
  • ( c 1 , c 2 c_1, c_2 c1,c2 ):超参数,平衡各部分的权重。
  • ( S [ π θ ] ( s ) S[\pi_\theta](s) S[πθ](s) ):策略的熵,用于鼓励探索。
(1) 策略损失 ( L p o l i c y ( θ ) L^{policy}(\theta) Lpolicy(θ) )

策略损失使用剪切(clipped)比率来限制策略更新幅度,避免过大的步长:
L p o l i c y ( θ ) = E [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ε , 1 + ε ) A ^ t ) ] L^{policy}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_t \right) \right] Lpolicy(θ)=E[min(rt(θ)A^t,clip(rt(θ),1ε,1+ε)A^t)]
其中:

  • ( r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst) ) 是新旧策略的概率比;
  • ( A ^ t \hat{A}_t A^t ) 是优势估计,由 Critic model 提供(如上所述,通过 GAE 或 TD 误差计算);
  • ( ε \varepsilon ε ) 是剪切范围的超参数。

体现 Critic 的地方:这里的 ( A ^ t \hat{A}_t A^t ) 直接依赖 Critic model 的输出 ( V ( s ) V(s) V(s) ),Critic 的准确性影响优势估计的质量,从而影响策略优化的方向和稳定性。

(2) 价值损失 ( L v a l u e ( ϕ ) L^{value}(\phi) Lvalue(ϕ) )

Critic model 的训练目标是最小化预测价值 ( V ϕ ( s ) V_\phi(s) Vϕ(s) ) 与实际回报(目标价值)的误差:
L v a l u e ( ϕ ) = E [ ( V ϕ ( s t ) − R t ) 2 ] L^{value}(\phi) = \mathbb{E} \left[ (V_\phi(s_t) - R_t)^2 \right] Lvalue(ϕ)=E[(Vϕ(st)Rt)2]
其中:

  • ( R t R_t Rt ) 是目标回报,可以是蒙特卡洛回报或 GAE 计算的回报;
  • 这部分损失直接优化 Critic model,使其更好地估计状态价值。
(3) 熵项(可选)

熵项 ( − c 2 S [ π θ ] ( s ) -c_2 S[\pi_\theta](s) c2S[πθ](s) ) 与 Critic 无关,用于防止策略过早收敛到局部最优。

总结 PPO 中 Critic 的体现
  • 计算优势:Critic model 输出 ( V ( s ) V(s) V(s) ),用于计算 ( A ( s , a ) A(s, a) A(s,a) ),这是策略梯度更新的核心信号。
  • 独立优化:Critic 有自己的损失函数 ( L v a l u e L^{value} Lvalue ),通过监督学习方式与回报对齐。
  • 整体作用:Critic 减少了策略梯度的方差,使得 PPO 能够在高维动作空间中更稳定地训练。

GRPO 与 PPO 的区别:为什么不用 Critic Model?

根据文档(第 2.2.1 节),GRPO(Group Relative Policy Optimization)放弃了 PPO 中的 Critic model,而是通过组评分(group scores)来估计基线。具体差异如下:

1. GRPO 的优势计算
  • 无 Critic:GRPO 不依赖独立的 Critic model 来预测 ( V ( s ) V(s) V(s) )。相反,它通过对一组输出 ( { o 1 , o 2 , … , o G } \{o_1, o_2, \ldots, o_G\} {o1,o2,,oG} ) 的奖励 ( { r 1 , r 2 , … , r G } \{r_1, r_2, \ldots, r_G\} {r1,r2,,rG} ) 进行统计计算来估计优势:
    A i = r i − mean ( { r 1 , r 2 , ⋯   , r G } ) std ( { r 1 , r 2 , ⋯   , r G } ) A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \cdots, r_G\})}{\text{std}(\{r_1, r_2, \cdots, r_G\})} Ai=std({r1,r2,,rG})rimean({r1,r2,,rG})
    这里,基线是组内奖励的均值,方差归一化用于标准化。
  • 优点:这种方法避免了训练一个单独的 Critic model,减少了计算开销和参数量,尤其在大规模 RL 中(如 DeepSeek-R1-Zero 的训练)非常重要。
2. GRPO 的损失函数

GRPO 的损失函数与 PPO 类似,但基线由组内统计替代:
J G R P O ( θ ) = E [ 1 G ∑ i = 1 G ( min ⁡ ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ε , 1 + ε ) A i ) − β D K L ( π θ ∥ π ref ) ) ] \mathcal{J}_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta \mathbb{D}_{KL}(\pi_\theta \| \pi_{\text{ref}}) \right) \right] JGRPO(θ)=E[G1i=1G(min(πθold(oiq)πθ(oiq)Ai,clip(πθold(oiq)πθ(oiq),1ε,1+ε)Ai)βDKL(πθπref))]

  • 无 ( L v a l u e L^{value} Lvalue ):由于没有 Critic,GRPO 不需要额外的价值损失项。
  • 依赖采样:优势 ( A i A_i Ai ) 完全基于采样的奖励统计,而不是模型预测。
3. 为什么 GRPO 不用 Critic?
  • 降低成本:文档明确提到,Critic model 通常与策略模型同等规模,训练它会显著增加计算负担。GRPO 通过组采样直接估计基线,简化了流程。
  • 适用场景:在 DeepSeek-R1-Zero 的训练中,奖励(如准确性奖励和格式奖励)是基于规则的,易于计算,不需要复杂的价值估计,适合用统计方法替代 Critic。

PPO 和 GRPO 的对比

特性PPOGRPO
Critic Model有,独立的价值网络估计 ( V ( s ) V(s) V(s) )无,使用组奖励统计作为基线
优势计算( A = r + γ V ( s ′ ) − V ( s ) A = r + \gamma V(s') - V(s) A=r+γV(s)V(s) ) 或 GAE( A i = r i − mean std A_i = \frac{r_i - \text{mean}}{\text{std}} Ai=stdrimean )
损失函数包含策略损失 + 价值损失 + 熵项仅策略损失 + KL 正则化,无价值损失
计算复杂度较高(两网络训练)较低(单网络 + 采样统计)

总结

  1. PPO 中的 value model 是如何使用的?
    • Value model(Critic)用于预测状态价值 ( V ( s ) V(s) V(s) ),作为基线计算优势 ( A ( s , a ) A(s, a) A(s,a) ),并通过 ( L v a l u e L^{value} Lvalue ) 单独优化。
  2. 是用在算优势那部分吗?
    • 是的,Critic 的核心作用是为优势函数提供 ( V ( s ) V(s) V(s) ) 和 ( V ( s ′ ) V(s') V(s) ),如 ( A = r + γ V ( s ′ ) − V ( s ) A = r + \gamma V(s') - V(s) A=r+γV(s)V(s) ) 或 GAE。
  3. 在 PPO 的 loss 中是怎么体现的?
    • 体现在两处:(1) 策略损失中的 ( A ^ t \hat{A}_t A^t ) 依赖 Critic 的输出;(2) 价值损失 ( L v a l u e L^{value} Lvalue ) 直接优化 Critic 的预测。

PPO 中的 KL 散度:在哪里体现?

在 PPO 中,KL 散度并不是直接作为一个独立的损失项加到最终的损失函数中,而是通过策略更新时的约束机制隐式地起作用。PPO 的核心思想是通过限制新旧策略之间的偏差来保证训练的稳定性,而 KL 散度正是用来衡量这种偏差的工具。具体来说,PPO 使用两种方式来实现这一约束:

1. Clipped Surrogate Objective(剪切代理目标)

PPO 的标准形式(也是最常见的形式)使用剪切(clipping)来限制新旧策略之间的变化,而不是直接在损失函数中添加 KL 散度项。其损失函数如下:

L p o l i c y ( θ ) = E [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ε , 1 + ε ) A ^ t ) ] L^{policy}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_t \right) \right] Lpolicy(θ)=E[min(rt(θ)A^t,clip(rt(θ),1ε,1+ε)A^t)]

其中:

  • ( r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst) ) 是新旧策略的概率比;
  • ( ε \varepsilon ε ) 是剪切范围的超参数(通常为 0.1 或 0.2)。

KL 散度的隐式体现

  • ( r t ( θ ) r_t(\theta) rt(θ) ) 表示新策略 ( π θ \pi_\theta πθ ) 和旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ) 在动作概率上的相对变化。当 ( r t ( θ ) r_t(\theta) rt(θ) ) 偏离 1 过多时(即新旧策略差异过大),剪切操作会限制更新幅度。
  • 这种剪切本质上是对新旧策略分布差异的一种近似控制,而 KL 散度正是衡量两个概率分布差异的数学工具。虽然公式中没有显式写出 ( D K L ( π θ ∣ ∣ π θ old ) D_{KL}(\pi_\theta || \pi_{\theta_{\text{old}}}) DKL(πθ∣∣πθold) ),但 ( r t ( θ ) r_t(\theta) rt(θ) ) 的剪切间接限制了 KL 散度的大小。
  • 在 PPO 的原始论文(Schulman et al., 2017)中,作者提到剪切目标是对信任区域(trust region)的一种简化,而信任区域通常由 KL 散度定义。
2. 自适应 KL 惩罚形式(Adaptive KL Penalty,较少使用)

PPO 的另一种变体(在论文中也有提及,但实践中不常见)通过显式添加 KL 散度惩罚项来替代剪切机制。这种形式的损失函数如下:

L p o l i c y ( θ ) = E [ r t ( θ ) A ^ t − β D K L ( π θ old ∣ ∣ π θ ) ] L^{policy}(\theta) = \mathbb{E} \left[ r_t(\theta) \hat{A}_t - \beta D_{KL}(\pi_{\theta_{\text{old}}} || \pi_\theta) \right] Lpolicy(θ)=E[rt(θ)A^tβDKL(πθold∣∣πθ)]

  • ( β \beta β ) 是一个动态调整的惩罚系数,根据实际的 KL 散度大小自适应变化:
    • 如果 ( D K L D_{KL} DKL ) 超过目标值(例如 0.01),则增大 ( β \beta β );
    • 如果 ( D K L D_{KL} DKL ) 小于目标值,则减小 ( β \beta β )。

这种形式直接使用了 KL 散度,但在实践中,剪切形式(clipped objective)更流行,因为它更简单且不需要手动调整 ( β \beta β )。

为什么标准 PPO 损失中没有显式 KL 项?
  • 标准 PPO(剪切版本)选择用剪切机制替代显式 KL 散度惩罚,因为:
    1. 计算简便:剪切只需计算概率比 ( r t ( θ ) r_t(\theta) rt(θ) ),而 KL 散度需要对整个分布求和或积分,计算成本更高。
    2. 稳定性:剪切提供了一个硬约束,效果直观且易于实现,避免了 ( β \beta β ) 自适应调节的不确定性。
  • 因此,尽管 KL 散度的概念在 PPO 的设计中至关重要(限制策略更新步长),但它通过剪切机制被“隐藏”在了实现中,而不是显式出现在损失公式里。

GRPO 中的 KL 散度:显式体现

相比之下,GRPO(Group Relative Policy Optimization)在损失函数中显式地包含了 KL 散度项,如你提供的公式:

J G R P O ( θ ) = E [ 1 G ∑ i = 1 G ( min ⁡ ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ε , 1 + ε ) A i ) − β D K L ( π θ ∣ ∣ π ref ) ) ] \mathcal{J}_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta D_{KL}(\pi_\theta || \pi_{\text{ref}}) \right) \right] JGRPO(θ)=E[G1i=1G(min(πθold(oiq)πθ(oiq)Ai,clip(πθold(oiq)πθ(oiq),1ε,1+ε)Ai)βDKL(πθ∣∣πref))]

  • 作用:这里的 ( D K L ( π θ ∣ ∣ π ref ) D_{KL}(\pi_\theta || \pi_{\text{ref}}) DKL(πθ∣∣πref) ) 是一个正则化项,限制新策略 ( π θ \pi_\theta πθ ) 偏离参考策略 ( π ref \pi_{\text{ref}} πref )(而不是旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ))的程度。
  • 与 PPO 的区别
    • GRPO 同时保留了剪切机制(类似 PPO)和显式 KL 项,提供了双重约束。
    • ( π ref \pi_{\text{ref}} πref ) 通常是一个固定的参考分布(例如初始策略),而 PPO 只关注 ( π θ old \pi_{\theta_{\text{old}}} πθold )(上一步的策略)。

PPO 中的 KL 散度代码实现

以下是用 PyTorch 实现的 PPO(剪切版本)核心部分的伪代码,展示 KL 散度的隐式约束如何通过剪切实现:

import torch
import torch.nn as nn

# 假设策略网络和价值网络已经定义
class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.actor = nn.Sequential(...)  # 输出动作分布参数(如均值和方差)
        self.critic = nn.Sequential(...) # 输出状态价值 V(s)

# 计算 PPO 损失
def compute_ppo_loss(actor_critic, states, actions, old_log_probs, advantages, returns, epsilon=0.2):
    # 前向传播:获取新策略的 log_prob 和价值估计
    dist = actor_critic.actor(states)  # 假设输出分布(如正态分布)
    new_log_probs = dist.log_prob(actions)
    values = actor_critic.critic(states)

    # 计算概率比 r_t(θ)
    ratios = torch.exp(new_log_probs - old_log_probs)  # r_t = π_θ(a|s) / π_θ_old(a|s)

    # 策略损失(clipped objective)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    # 价值损失
    value_loss = ((values - returns) ** 2).mean()

    # 总损失(无显式 KL 项)
    total_loss = policy_loss + 0.5 * value_loss  # 0.5 是超参数 c1

    return total_loss

# 训练循环(伪代码)
def train_ppo(actor_critic, optimizer, data):
    states, actions, old_log_probs, advantages, returns = data
    optimizer.zero_grad()
    loss = compute_ppo_loss(actor_critic, states, actions, old_log_probs, advantages, returns)
    loss.backward()
    optimizer.step()

# 可选:计算实际 KL 散度(仅用于监控,不影响损失)
def compute_kl_divergence(dist_new, dist_old):
    return torch.distributions.kl_divergence(dist_new, dist_old).mean()
关键点解释:
  1. 剪切机制
    • ratios 表示 ( r t ( θ ) r_t(\theta) rt(θ) ),通过 torch.clamp 限制在 ( [ 1 − ε , 1 + ε ] [1-\varepsilon, 1+\varepsilon] [1ε,1+ε] ) 范围内。
    • 这隐式地控制了新旧策略的 KL 散度,避免过大更新。
  2. 无显式 KL 项
    • 损失函数中没有直接计算 ( D K L D_{KL} DKL ),但可以通过 compute_kl_divergence 单独监控(例如打印日志)。
  3. 实现细节
    • old_log_probs 是从旧策略采样的动作的对数概率,需要在新策略更新前保存。
    • dist 是新策略输出的分布(如 torch.distributions.Normal)。
如果使用自适应 KL 惩罚形式:

以下是 PPO 的 KL 惩罚版本的伪代码:

def compute_ppo_kl_loss(actor_critic, states, actions, old_dist, advantages, beta=0.01):
    dist = actor_critic.actor(states)
    new_log_probs = dist.log_prob(actions)
    ratios = torch.exp(new_log_probs - old_dist.log_prob(actions))

    # 策略损失
    policy_loss = -(ratios * advantages).mean()

    # KL 散度惩罚
    kl_div = torch.distributions.kl_divergence(dist, old_dist).mean()
    total_loss = policy_loss + beta * kl_div

    return total_loss
  • 显式 KL:这里直接计算 ( D K L D_{KL} DKL ) 并加到损失中,( β \beta β ) 需要根据 KL 大小动态调整。

总结

  • PPO 中的 KL 散度
    • 在标准剪切版本中,KL 散度通过 ( r t ( θ ) r_t(\theta) rt(θ) ) 的剪切隐式体现,不出现在损失公式中。
    • 在 KL 惩罚版本中,KL 散度显式加到损失函数中,但实践中较少使用。
  • GRPO 中的 KL 散度
    • 显式出现在损失公式中,作为正则化项约束新策略偏离参考策略。
  • 代码实现
    • 剪切版本简单高效,只需计算概率比并剪切;
    • KL 惩罚版本需要额外计算分布间的 KL 散度,复杂度稍高。

PPO 中的 Clip 为什么间接代替了 KL 散度?

在 PPO 中,剪切机制(clipped surrogate objective)是为了实现信任区域(trust region)优化的简单近似,而信任区域通常是通过 KL 散度来定义的:

  • 信任区域的意义:强化学习中,策略更新需要在新旧策略之间保持一定的一致性,避免过大的步长导致性能崩溃。KL 散度 ( D K L ( π θ old ∣ ∣ π θ ) D_{KL}(\pi_{\theta_{\text{old}}} || \pi_\theta) DKL(πθold∣∣πθ) ) 是衡量新旧策略分布差异的自然选择。
  • PPO 的剪切设计
    • PPO 使用概率比 ( r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst) ) 的剪切(限制在 ( [ 1 − ε , 1 + ε ] [1-\varepsilon, 1+\varepsilon] [1ε,1+ε] ))来近似控制新旧策略的偏差。
    • 当 ( r t ( θ ) r_t(\theta) rt(θ)) 偏离 1 过多时(即新策略偏离旧策略过远),剪切会限制梯度更新幅度。这种限制本质上是对 KL 散度的间接约束,因为 ( r t ( θ ) r_t(\theta) rt(θ) ) 的变化与 KL 散度密切相关(例如,当分布差异小时,( r t ( θ ) ≈ 1 + D K L r_t(\theta) \approx 1 + D_{KL} rt(θ)1+DKL ) 的线性近似)。
  • 为什么不用显式 KL?
    • 计算 KL 散度需要对整个分布求和或积分,成本较高,而剪切只需计算单点概率比,简单高效。
    • PPO 的目标是提供一个易于实现的算法,剪切机制在实践中效果很好,且避免了手动调整 KL 惩罚系数(如 ( β \beta β ))的不确定性。

因此,PPO 的剪切机制是对信任区域的“代理”(surrogate),间接实现了 KL 散度的约束功能。


GRPO 为什么同时有 Clip 和 KL 散度?

GRPO 的损失函数如下:

J G R P O ( θ ) = E [ 1 G ∑ i = 1 G ( min ⁡ ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ε , 1 + ε ) A i ) − β D K L ( π θ ∣ ∣ π ref ) ) ] \mathcal{J}_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta D_{KL}(\pi_\theta || \pi_{\text{ref}}) \right) \right] JGRPO(θ)=E[G1i=1G(min(πθold(oiq)πθ(oiq)Ai,clip(πθold(oiq)πθ(oiq),1ε,1+ε)Ai)βDKL(πθ∣∣πref))]

GRPO 同时保留剪切和显式 KL 散度的原因可以从以下几个方面理解:

1. 双重约束的不同目标
  • Clip 的作用
    • 类似于 PPO,剪切限制了新策略 ( π θ \pi_\theta πθ ) 相对于旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ) 的单步更新幅度,确保每次迭代不会偏离过远。这是局部信任区域的保障,关注的是“当前更新”的稳定性。
  • KL 散度的作用
    • GRPO 中的 ( D K L ( π θ ∣ ∣ π ref ) D_{KL}(\pi_\theta || \pi_{\text{ref}}) DKL(πθ∣∣πref) ) 是相对于一个固定的参考策略 ( π ref \pi_{\text{ref}} πref )(通常是初始策略或某个基准策略)的全局约束。它防止新策略在整个训练过程中偏离初始分布太远,起到长期正则化的作用。
  • 区别
    • Clip 关注的是短期的、迭代间的稳定性(( π θ \pi_\theta πθ ) vs. ( π θ old \pi_{\theta_{\text{old}}} πθold ))。
    • KL 散度关注的是长期的、相对于初始状态的稳定性(( π θ \pi_\theta πθ ) vs. ( π ref \pi_{\text{ref}} πref ))。
2. GRPO 的设计背景
  • 去掉 Critic 的影响:GRPO 放弃了 PPO 中的 Critic model,使用组内奖励统计来计算优势 ( A i A_i Ai )。这种设计减少了计算成本,但可能导致策略更新方向不够稳定(因为没有价值函数提供基线)。因此,GRPO 通过双重约束(clip + KL)来增强稳定性。
  • 语言模型的复杂性:在 DeepSeek-R1-Zero 的训练场景中,输出空间(如自然语言 token 序列)比传统 RL(如游戏)的动作空间更复杂。显式 KL 散度可以防止模型生成过于离谱的输出(例如语言混合或不可读内容),这在文档中提到的 DeepSeek-R1-Zero 的局限性中有所体现。
3. 实际效果
  • 互补性:剪切和 KL 散度的组合提供了更强的控制:
    • 剪切确保每次更新不会过于激进;
    • KL 散度防止模型在长时间训练后“漂移”到不可控的状态。
  • 实验驱动:GRPO 的设计可能基于实验观察,发现单独使用剪切不足以应对大规模语言模型训练中的挑战(如 reward hacking 或分布崩溃),因此引入显式 KL 项作为额外保障。

PPO 可以同时拥有 Clip 和 KL 散度吗?

理论上,PPO 完全可以同时拥有剪切和显式 KL 散度,而且这种组合在某些变体中已经被提出或实验过。下面是分析和实现思路:

1. 可行性分析
  • 理论依据
    • PPO 的剪切机制是对信任区域的近似,而显式 KL 散度是信任区域的直接表达。结合起来可以提供更细粒度的控制。
    • 例如,可以用剪切限制单步更新,用 KL 散度正则化全局分布,类似于 GRPO 的思路。
  • 潜在优势
    • 在高维或复杂任务中(如语言生成),仅靠剪切可能不足以防止策略分布的退化,添加 KL 散度可以增强鲁棒性。
  • 潜在问题
    • 超参数调节:需要同时调整剪切范围 ( ε \varepsilon ε ) 和 KL 惩罚系数 ( β \beta β ),增加了调参难度。
    • 计算开销:显式计算 KL 散度会增加计算负担,尤其是对于连续分布或高维离散分布。
2. PPO + Clip + KL 的损失函数

如果在 PPO 中加入显式 KL 散度,损失函数可能如下:

L P P O + K L ( θ ) = E [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ε , 1 + ε ) A ^ t ) − β D K L ( π θ ∣ ∣ π ref ) ] + c 1 L v a l u e ( ϕ ) L^{PPO+KL}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_t \right) - \beta D_{KL}(\pi_\theta || \pi_{\text{ref}}) \right] + c_1 L^{value}(\phi) LPPO+KL(θ)=E[min(rt(θ)A^t,clip(rt(θ),1ε,1+ε)A^t)βDKL(πθ∣∣πref)]+c1Lvalue(ϕ)

  • 与 GRPO 的区别
    • PPO 保留了 Critic model 和价值损失 ( L v a l u e L^{value} Lvalue ),而 GRPO 没有。
    • ( π ref \pi_{\text{ref}} πref ) 可以选择为 ( π θ old \pi_{\theta_{\text{old}}} πθold )(动态参考旧策略)或固定的初始策略。
3. 代码实现示例

以下是用 PyTorch 实现的 PPO + Clip + KL 的伪代码:

import torch
import torch.nn as nn

class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.actor = nn.Sequential(...)  # 输出动作分布
        self.critic = nn.Sequential(...) # 输出状态价值

def compute_ppo_with_kl_loss(actor_critic, states, actions, old_log_probs, advantages, returns, ref_dist, epsilon=0.2, beta=0.01):
    # 新策略分布和价值
    dist = actor_critic.actor(states)
    new_log_probs = dist.log_prob(actions)
    values = actor_critic.critic(states)

    # 概率比和剪切损失
    ratios = torch.exp(new_log_probs - old_log_probs)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    # 价值损失
    value_loss = ((values - returns) ** 2).mean()

    # 显式 KL 散度(相对于参考分布)
    kl_div = torch.distributions.kl_divergence(dist, ref_dist).mean()

    # 总损失
    total_loss = policy_loss + 0.5 * value_loss + beta * kl_div

    return total_loss

# 训练循环
def train_ppo_with_kl(actor_critic, optimizer, data, ref_dist):
    states, actions, old_log_probs, advantages, returns = data
    optimizer.zero_grad()
    loss = compute_ppo_with_kl_loss(actor_critic, states, actions, old_log_probs, advantages, returns, ref_dist)
    loss.backward()
    optimizer.step()
  • 关键点
    • ref_dist 是参考分布,可以是初始策略的输出分布(需预先保存)。
    • kl_div 使用 PyTorch 的 kl_divergence 函数计算,适用于常见分布(如正态分布或离散分布)。
4. PPO 不常用 Clip + KL 的原因
  • 冗余性:剪切已经很好地实现了信任区域约束,添加 KL 散度可能收益不大,反而增加复杂度。
  • 简单性优先:PPO 的设计目标是简单易用,剪切机制在大多数任务中足够稳定,显式 KL 散度显得多余。
  • 任务依赖:在语言模型等复杂任务中,显式 KL 可能更有价值,但在传统 RL 任务(如 Atari 游戏)中,剪切已足够。

总结

  • 为什么 GRPO 既有 Clip 又有 KL?
    • GRPO 去掉了 Critic,依赖双重约束(剪切控制单步稳定性,KL 控制全局正则化)来应对复杂任务(如语言生成)的挑战。
  • PPO 可以同时有两者吗?
    • 可以,理论上可行且实现简单,但实践中很少这样做,因为剪切已足够,且增加 KL 散度会提高计算成本和调参难度。
  • 区别的根源
    • GRPO 的设计适应了大规模语言模型训练的需求(无 Critic、高维输出),而 PPO 更通用,倾向于简单高效。

后记

2025年3月3日13点42分于上海,在grok3大模型辅助下完成。

### 如何对 DeepSeek-R1 模型进行训练和微调 #### 使用 LoRA 和 Unsloth 优化 DeepSeek-R1 的微调过程 为了有效降低资源消耗并提高效率,在消费级硬件上可以采用 LoRA(低秩自适应)以及 Unsloth 来优化 DeepSeek-R1 的微调流程[^1]。 LoRA 是一种参数高效的迁移学习方法,它通过引入少量可训练参数来调整预训练模型的行为。这种方法允许只更新一小部分权重而不是整个网络的所有连接,从而显著减少了所需的计算量和内存占用。对于想要利用现有高性能 GPU 而不是依赖昂贵的专业设备的人来说尤其有用。 Unsloth 则是一个专门设计用于加速深度神经网络训练速度的库,能够进一步提升基于 PyTorch 或 TensorFlow 构建的应用程序的表现。结合这两个工具可以帮助用户实现在普通计算机上的高效微调操作。 #### 准备工作与环境配置 在开始之前,确保已经安装好 Python 及其相关依赖项,并设置了一个合适的虚拟环境。接着按照官方文档说明下载所需的数据集及预处理脚本[^3]。 ```bash pip install torch torchvision transformers loralib unsloth ``` 这段命令会安装必要的软件包,包括 PyTorch、Transformers 库以及其他辅助组件。 #### 实施具体步骤 加载预训练好的 DeepSeek-R1 模型实例: ```python from transformers import AutoModelForCausalLM, AutoTokenizer model_name_or_path = "path_to_deepseek_r1" tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) model = AutoModelForCausalLM.from_pretrained(model_name_or_path).cuda() ``` 应用 LoRA 修改策略到选定层上: ```python import loralib as lora lora_config = { 'r': 8, 'alpha': 16, } for name, module in model.named_modules(): if isinstance(module, nn.Linear): new_module = lora.LoRALinear( in_features=module.in_features, out_features=module.out_features, r=lora_config['r'], alpha=lora_config['alpha'] ) setattr(model, name, new_module.cuda()) ``` 此代码片段遍历了模型内部所有的线性变换模块并将它们替换为带有 LoRA 参数的新版本。 最后一步是定义损失函数、优化器及其超参设定,启动实际训练循环: ```python optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) # 假设 data_loader 已经被正确定义过... for epoch in range(num_epochs): for batch in data_loader: inputs = tokenizer(batch["text"], return_tensors="pt").to('cuda') outputs = model(**inputs, labels=inputs.input_ids) loss = outputs.loss optimizer.zero_grad() loss.backward() optimizer.step() torch.save({ 'epoch': num_epochs, 'model_state_dict': model.state_dict(), }, f'./checkpoint_{num_epochs}.pth') ``` 上述代码展示了完整的训练周期逻辑框架;当然还需要根据实际情况调整细节之处比如批次大小、迭代次数等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值