DeepSeek-R1-Zero 的训练过程是基于纯强化学习(Reinforcement Learning, RL)的方法,不依赖监督微调(Supervised Fine-Tuning, SFT)作为预备步骤。以下是 DeepSeek-R1-Zero 训练过程的详细说明,基于第 2.2 节(DeepSeek-R1-Zero: Reinforcement Learning on the Base Model)的内容:
DeepSeek-R1-Zero 的训练过程
1. 起点:基于基础模型
- 基础模型:DeepSeek-R1-Zero 以 DeepSeek-V3-Base 作为基础模型。
- 无监督数据:训练从一开始就避免使用任何监督数据,目标是通过纯 RL 过程让模型自我进化,探索其推理能力的潜力。
2. 强化学习算法:Group Relative Policy Optimization (GRPO)
- 算法选择:使用 GRPO(Shao et al., 2024)作为 RL 框架。GRPO 是一种高效的策略优化方法,相较于传统 RL(如 PPO),它不依赖与策略模型同等规模的 critic 模型,而是通过组评分(group scores)估算基线,从而降低训练成本。
- 优化目标:GRPO 的目标是通过采样一组输出并根据奖励优化策略模型。具体公式如下:
- 对于每个问题 ( q q q ),从旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ) 中采样一组输出 ( { o 1 , o 2 , ⋯ , o G } \{o_1, o_2, \cdots, o_G\} {o1,o2,⋯,oG} );
- 优化新策略 (
π
θ
\pi_\theta
πθ ) 的目标是最大化以下函数:
J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ old ( O ∣ q ) ] [ 1 G ∑ i = 1 G ( min ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ε , 1 + ε ) A i ) − β D K L ( π θ ∥ π ref ) ) ] \mathcal{J}_{GRPO}(\theta) = \mathbf{E} \left[ q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(O \mid q) \right] \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta \mathbb{D}_{KL}(\pi_\theta \| \pi_{\text{ref}}) \right) \right] JGRPO(θ)=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)][G1i=1∑G(min(πθold(oi∣q)πθ(oi∣q)Ai,clip(πθold(oi∣q)πθ(oi∣q),1−ε,1+ε)Ai)−βDKL(πθ∥πref))]
其中:- (
A
i
A_i
Ai ) 是优势函数,根据一组奖励 (
{
r
1
,
r
2
,
…
,
r
G
}
\{r_1, r_2, \ldots, r_G\}
{r1,r2,…,rG} ) 计算:
A i = r i − mean ( { r 1 , r 2 , ⋯ , r G } ) std ( { r 1 , r 2 , ⋯ , r G } ) A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \cdots, r_G\})}{\text{std}(\{r_1, r_2, \cdots, r_G\})} Ai=std({r1,r2,⋯,rG})ri−mean({r1,r2,⋯,rG}) - ( D K L \mathbb{D}_{KL} DKL ) 是 KL 散度,用于正则化以防止策略偏离参考策略 ( π ref \pi_{\text{ref}} πref ) 过远;
- ( ε \varepsilon ε ) 和 ( β \beta β ) 是超参数,用于控制剪切范围和 KL 散度的权重。
- (
A
i
A_i
Ai ) 是优势函数,根据一组奖励 (
{
r
1
,
r
2
,
…
,
r
G
}
\{r_1, r_2, \ldots, r_G\}
{r1,r2,…,rG} ) 计算:
3. 奖励模型(Reward Modeling)
- 奖励来源:奖励是训练信号的核心,决定模型优化的方向。DeepSeek-R1-Zero 使用基于规则的奖励系统,包括两种类型:
- 准确性奖励(Accuracy Rewards):
- 用于评估回答是否正确。例如,对于数学问题,要求模型以指定格式(如框住最终答案)输出结果,便于规则验证;对于 LeetCode 问题,则通过编译器基于预定义测试用例生成反馈。
- 格式奖励(Format Rewards):
- 强制模型将推理过程放在
<think>
和</think>
标签之间,以确保输出结构符合预期。
- 强制模型将推理过程放在
- 准确性奖励(Accuracy Rewards):
- 避免神经奖励模型:文档提到,未使用基于神经网络的奖励模型(如过程或结果神经奖励模型),因为在大规模 RL 中这些模型容易导致“奖励黑客”(reward hacking),且重新训练奖励模型会增加资源消耗和复杂性。
4. 训练模板(Training Template)
-
模板设计:训练使用一个简单模板(见Table 1),要求模型首先生成推理过程(标记为
<think>
和</think>
),然后提供最终答案(标记为<answer>
和</answer>
)。
-
无内容偏见:模板仅限制结构,不强制要求特定推理策略(如反思或特定解题方法),以观察模型在 RL 过程中的自然演化。
5. 训练过程与结果
-
性能提升:在 RL 训练中,DeepSeek-R1-Zero 的性能逐步提高。例如,在 AIME 2024 基准上的 pass@1 分数从初始的 15.6% 提升到 71.0%,通过多数投票(majority voting)进一步达到 86.7%,媲美 OpenAI-o1-0912。
-
自我进化:模型通过 RL 自然发展出复杂行为,如:
- 延长推理时间:生成更长的推理链(从数百到数千个 token),以解决更复杂的任务。
- 反思与替代策略:自发出现反思(重新评估先前步骤)和探索替代解法的能力。
-
“Aha Moment”:训练中观察到“顿悟时刻”,如 Table 3 所示,模型学会重新评估初始方法并调整策略,显示出类人化的推理语气。
-
训练步骤:经过数千次 RL 迭代,模型在推理基准上表现出超强性能。
6. 挑战与局限
- 问题:尽管推理能力强大,DeepSeek-R1-Zero 存在可读性差(poor readability)和语言混合(language mixing)等问题。这些问题推动了后续 DeepSeek-R1 的开发,引入冷启动数据和多阶段训练来解决。
总结
DeepSeek-R1-Zero 的训练过程是一个从基础模型 DeepSeek-V3-Base 开始,基于 GRPO 算法的大型纯强化学习实验。通过规则驱动的准确性和格式奖励,结合简单的训练模板,模型在不依赖监督数据的情况下,逐步自我进化,发展出强大的推理能力(如长链推理、反思等)。其性能在 AIME 等基准上显著提升,但也暴露出可读性和语言一致性方面的不足,为后续改进提供了方向。
DeepSeek-R1-Zero训练代码实现
为 DeepSeek-R1-Zero 的训练过程,尤其是奖励模型(Reward Modeling)部分,提供代码实现。这部分是基于规则的奖励系统,包括准确性奖励(Accuracy Rewards)和格式奖励(Format Rewards),并且避免使用神经奖励模型。这里提供一个简化的 PyTorch 实现,重点体现奖励模型的设计和集成。
设计思路
- 奖励模型的目标:
- 准确性奖励:检查模型输出是否正确,例如数学问题的最终答案是否与标准答案匹配,或 LeetCode 问题的代码是否通过测试用例。
- 格式奖励:确保输出符合
<think>推理过程</think><answer>答案</answer>
的结构。
- 实现方式:
- 使用规则逻辑(而不是神经网络)计算奖励。
- 将奖励融入 GRPO 的训练循环,用于计算优势函数 ( A i A_i Ai ) 和优化策略。
- 代码结构:
- 定义一个奖励计算函数,处理准确性和格式两部分。
- 将其嵌入 GRPO 的损失计算中。
以下是代码实现:
代码实现
import torch
import torch.nn as nn
import re
from typing import List, Tuple
# 假设的策略网络(Actor),输出 token 序列的分布
class PolicyNetwork(nn.Module):
def __init__(self, input_dim, hidden_dim, vocab_size):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, vocab_size) # 输出每个 token 的 logits
)
def forward(self, states):
return self.net(states) # 返回 logits
# 奖励模型:基于规则计算准确性和格式奖励
def compute_reward(output: str, ground_truth: str = None, problem_type: str = "math") -> float:
"""
计算单个输出的奖励。
- output: 模型生成的文本
- ground_truth: 标准答案(数学问题提供,LeetCode 可为空)
- problem_type: "math" 或 "leetcode"
"""
# 初始化奖励
total_reward = 0.0
# 1. 格式奖励:检查是否符合 <think></think><answer></answer> 结构
think_pattern = r"<think>.*?</think>"
answer_pattern = r"<answer>.*?</answer>"
has_think = bool(re.search(think_pattern, output))
has_answer = bool(re.search(answer_pattern, output))
format_reward = 1.0 if (has_think and has_answer) else 0.0
total_reward += format_reward # 格式奖励权重为 1.0
# 2. 准确性奖励
if problem_type == "math" and ground_truth:
# 提取 <answer> 中的答案
answer_match = re.search(r"<answer>(.*?)</answer>", output)
if answer_match:
predicted_answer = answer_match.group(1).strip()
# 检查是否与标准答案匹配(假设答案是数字或简单字符串)
accuracy_reward = 1.0 if predicted_answer == ground_truth else 0.0
total_reward += accuracy_reward # 准确性奖励权重为 1.0
else:
total_reward += 0.0 # 无答案则准确性奖励为 0
elif problem_type == "leetcode":
# 模拟 LeetCode 测试用例验证(这里简化为编译器检查逻辑)
# 假设 output 是代码,调用外部函数验证(伪代码)
from compiler import run_tests # 假设的外部编译器接口
test_result = run_tests(output) # 返回 True/False 表示是否通过测试
accuracy_reward = 1.0 if test_result else 0.0
total_reward += accuracy_reward
return total_reward
# GRPO 训练核心逻辑
def compute_grpo_loss(
policy: PolicyNetwork,
old_policy: PolicyNetwork,
ref_policy: PolicyNetwork,
states: torch.Tensor,
outputs: List[str],
rewards: List[float],
epsilon: float = 0.2,
beta: float = 0.01
) -> torch.Tensor:
"""
计算 GRPO 损失。
- policy: 当前策略网络
- old_policy: 旧策略网络(用于采样和比率计算)
- ref_policy: 参考策略(用于 KL 散度)
- states: 输入状态(问题描述的 embedding)
- outputs: 模型生成的输出序列
- rewards: 每个输出的奖励
"""
G = len(outputs) # 组大小
assert G > 0, "Group size must be positive"
# 计算新旧策略的概率比和 KL 散度
policy_logits = policy(states) # [batch_size, seq_len, vocab_size]
old_logits = old_policy(states).detach() # 旧策略不更新梯度
ref_logits = ref_policy(states).detach() # 参考策略固定
# 假设 outputs 已转换为 token IDs,这里简化为对数概率的近似
# 在实际中需要序列化处理,这里用占位符
log_probs = torch.softmax(policy_logits, dim=-1).log() # 模拟 log π_θ(o|q)
old_log_probs = torch.softmax(old_logits, dim=-1).log() # log π_θ_old(o|q)
ref_log_probs = torch.softmax(ref_logits, dim=-1).log() # log π_ref(o|q)
# 计算概率比 r_t(θ)
ratios = torch.exp(log_probs - old_log_probs) # π_θ(o|q) / π_θ_old(o|q)
# 计算优势 A_i
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
mean_reward = rewards_tensor.mean()
std_reward = rewards_tensor.std() + 1e-6 # 避免除零
advantages = (rewards_tensor - mean_reward) / std_reward
# 剪切损失项
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
clipped_loss = torch.min(surr1, surr2).mean()
# KL 散度正则化项
kl_div = (ref_log_probs - log_probs).mean() # 简化的 KL 散度计算
# 总损失
loss = clipped_loss - beta * kl_div
return -loss # 最大化 J_GRPO 等价于最小化 -loss
# 训练循环
def train_deepseek_r1_zero(
policy: PolicyNetwork,
optimizer: torch.optim.Optimizer,
problems: List[Tuple[str, str]], # (问题描述, 标准答案)
num_epochs: int = 100,
group_size: int = 16
):
ref_policy = PolicyNetwork(...) # 初始化参考策略(固定)
old_policy = PolicyNetwork(...) # 初始化旧策略
old_policy.load_state_dict(policy.state_dict()) # 同步初始参数
for epoch in range(num_epochs):
for problem, ground_truth in problems:
# 准备输入状态(假设问题已转为 embedding)
state = torch.tensor(...) # 简化为占位符
# 从旧策略采样一组输出
outputs = []
for _ in range(group_size):
with torch.no_grad():
logits = old_policy(state)
# 假设通过采样生成序列(简化处理)
output = "sampled_output" # 实际需解码为文本
outputs.append(output)
# 计算奖励
rewards = [compute_reward(output, ground_truth, "math") for output in outputs]
# 计算 GRPO 损失并更新策略
optimizer.zero_grad()
loss = compute_grpo_loss(policy, old_policy, ref_policy, state, outputs, rewards)
loss.backward()
optimizer.step()
# 更新旧策略
old_policy.load_state_dict(policy.state_dict())
print(f"Epoch {epoch}, Loss: {loss.item()}")
# 示例调用
policy = PolicyNetwork(input_dim=512, hidden_dim=1024, vocab_size=30000)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
problems = [("Solve 2+2", "4"), ("Solve 3*3", "9")]
train_deepseek_r1_zero(policy, optimizer, problems)
代码说明
1. 奖励模型(compute_reward
)
- 准确性奖励:
- 对于数学问题,使用正则表达式提取
<answer>
中的答案,与ground_truth
比较。如果匹配,奖励为 1.0,否则为 0.0。 - 对于 LeetCode 问题,假设有一个外部函数
run_tests
(这里是伪代码),检查代码是否通过测试用例。
- 对于数学问题,使用正则表达式提取
- 格式奖励:
- 使用正则表达式检查
<think>
和<answer>
标签是否存在,符合要求则奖励 1.0,否则 0.0。
- 使用正则表达式检查
- 总奖励:
- 简单相加准确性和格式奖励(权重均为 1.0,可根据需要调整)。
- 避免神经奖励模型:
- 实现完全基于规则逻辑,不依赖任何神经网络,确保简单性和可控性,避免 reward hacking。
2. GRPO 损失(compute_grpo_loss
)
- 输入:
outputs
是模型生成的文本序列,rewards
是对应奖励。
- 优势计算:
- 根据公式 ( A i = r i − mean std A_i = \frac{r_i - \text{mean}}{\text{std}} Ai=stdri−mean ),直接用组内奖励统计计算。
- 剪切和 KL:
- 保留 PPO 的剪切机制,同时加入 ( D K L ( π θ ∣ ∣ π ref ) D_{KL}(\pi_\theta || \pi_{\text{ref}}) DKL(πθ∣∣πref) )。
- 简化假设:
- 这里假设
log_probs
从 logits 计算,实际语言模型需处理序列生成(例如用 transformers 的generate
方法)。
- 这里假设
3. 训练循环(train_deepseek_r1_zero
)
- 采样:从旧策略生成一组输出(这里简化为占位符,实际需实现 token-by-token 生成)。
- 奖励计算:调用
compute_reward
为每个输出评分。 - 优化:计算 GRPO 损失并更新策略。
注意事项
- 简化之处:
- 代码中省略了序列生成的细节(例如如何将 logits 转为文本)。实际中需结合 transformers 库,使用
model.generate()
并计算每步的概率。 state
的生成(问题转为 embedding)未实现,可用预训练模型(如 BERT)编码问题。
- 代码中省略了序列生成的细节(例如如何将 logits 转为文本)。实际中需结合 transformers 库,使用
- 可扩展性:
- 可以为不同问题类型添加更多规则(例如科学推理、逻辑题)。
- 调整奖励权重(例如格式奖励占 0.3,准确性占 0.7)。
- 性能优化:
- 批量处理多组样本以加速训练。
- 使用分布式计算支持大规模 RL。
如何体现原论文中的设计?
- 规则奖励:
compute_reward
严格遵循文档描述,使用正则表达式和逻辑判断实现准确性和格式检查。 - 避免神经模型:奖励计算不涉及任何神经网络,完全基于规则,符合文档避免 reward hacking 的目标。
- GRPO 集成:奖励值直接用于优势计算,驱动策略优化,与 DeepSeek-R1-Zero 的训练过程一致。
PPO和GRPO的区别
PPO 中的 Critic Model(Value Model)的作用
在 PPO(以及许多 Actor-Critic 类型的强化学习算法)中,critic model(通常称为价值模型)是一个独立的神经网络,用于估计状态的价值函数 ( V ( s ) V(s) V(s) )。它的主要作用是为策略优化提供一个基线(baseline),从而计算优势函数 ( A ( s , a ) A(s, a) A(s,a) ),帮助减少策略梯度估计的方差,提升训练的稳定性。
1. Critic Model 如何使用?
- 价值函数估计:Critic model 的任务是预测给定状态 ( s s s ) 下的期望累计回报,即 ( V π ( s ) = E [ ∑ t γ t r t ∣ s 0 = s ] V^\pi(s) = \mathbb{E} [ \sum_t \gamma^t r_t | s_0 = s ] Vπ(s)=E[∑tγtrt∣s0=s] ),其中 ( γ \gamma γ ) 是折扣因子,( r t r_t rt ) 是奖励。
- 与策略模型的关系:在 PPO 中,策略模型(Actor)负责输出动作分布 ( π ( a ∣ s ) \pi(a|s) π(a∣s) ),而 Critic model 提供价值估计,二者共同优化。
- 训练数据:Critic model 通常通过时序差分(TD)学习或蒙特卡洛回报来更新,使用实际获得的回报 ( R R R )(如 GAE,Generalized Advantage Estimation)作为监督信号。
2. Critic Model 在计算优势(Advantage)中的作用
是的,Critic model 的核心作用之一就是用于计算优势函数 ( A ( s , a ) A(s, a) A(s,a) )。优势函数衡量的是在状态 ( s s s ) 下选择动作 ( a a a ) 相比于平均情况(即状态价值 ( V ( s ) V(s) V(s) ))的额外收益。PPO 中常用的计算方式是:
A ( s , a ) = Q ( s , a ) − V ( s ) A(s, a) = Q(s, a) - V(s) A(s,a)=Q(s,a)−V(s)
其中:
- ( Q ( s , a ) Q(s, a) Q(s,a) ) 是动作价值函数,表示在状态 ( s s s ) 下执行动作 ( a a a ) 后的期望回报;
- ( V ( s ) V(s) V(s) ) 是状态价值函数,由 Critic model 预测。
在实践中,直接计算 ( Q ( s , a ) Q(s, a) Q(s,a) ) 需要完整的环境动态,而这通常不可行。因此,PPO 通常使用 时序差分误差 或 广义优势估计(GAE) 来近似 ( A ( s , a ) A(s, a) A(s,a) ):
- 简单 TD 误差:
A ( s , a ) = r + γ V ( s ′ ) − V ( s ) A(s, a) = r + \gamma V(s') - V(s) A(s,a)=r+γV(s′)−V(s)
其中 ( r r r ) 是即时奖励,( s ′ s' s′ ) 是下一状态。 - GAE(更常用):
A t = ∑ l = 0 ∞ ( γ λ ) l δ t + l , 其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) A_t = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}, \quad \text{其中} \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) At=l=0∑∞(γλ)lδt+l,其中δt=rt+γV(st+1)−V(st)
这里 ( λ \lambda λ ) 是 GAE 的超参数,用于平衡偏差和方差。
Critic model 通过输出 ( V ( s ) V(s) V(s) ) 和 ( V ( s ′ ) V(s') V(s′) ),为这些计算提供基线,使得优势函数能够反映动作的相对优劣。
3. PPO 的损失函数中如何体现?
PPO 的损失函数由两部分组成:策略损失(policy loss) 和 价值损失(value loss),分别优化 Actor 和 Critic,加上一个可选的熵正则化项。完整的 PPO 损失函数如下:
L P P O ( θ , ϕ ) = L p o l i c y ( θ ) + c 1 L v a l u e ( ϕ ) − c 2 S [ π θ ] ( s ) L^{PPO}(\theta, \phi) = L^{policy}(\theta) + c_1 L^{value}(\phi) - c_2 S[\pi_\theta](s) LPPO(θ,ϕ)=Lpolicy(θ)+c1Lvalue(ϕ)−c2S[πθ](s)
- ( θ \theta θ ) 和 ( ϕ \phi ϕ ):分别表示策略网络(Actor)和价值网络(Critic)的参数。
- ( c 1 , c 2 c_1, c_2 c1,c2 ):超参数,平衡各部分的权重。
- ( S [ π θ ] ( s ) S[\pi_\theta](s) S[πθ](s) ):策略的熵,用于鼓励探索。
(1) 策略损失 ( L p o l i c y ( θ ) L^{policy}(\theta) Lpolicy(θ) )
策略损失使用剪切(clipped)比率来限制策略更新幅度,避免过大的步长:
L
p
o
l
i
c
y
(
θ
)
=
E
[
min
(
r
t
(
θ
)
A
^
t
,
clip
(
r
t
(
θ
)
,
1
−
ε
,
1
+
ε
)
A
^
t
)
]
L^{policy}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_t \right) \right]
Lpolicy(θ)=E[min(rt(θ)A^t,clip(rt(θ),1−ε,1+ε)A^t)]
其中:
- ( r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(at∣st)πθ(at∣st) ) 是新旧策略的概率比;
- ( A ^ t \hat{A}_t A^t ) 是优势估计,由 Critic model 提供(如上所述,通过 GAE 或 TD 误差计算);
- ( ε \varepsilon ε ) 是剪切范围的超参数。
体现 Critic 的地方:这里的 ( A ^ t \hat{A}_t A^t ) 直接依赖 Critic model 的输出 ( V ( s ) V(s) V(s) ),Critic 的准确性影响优势估计的质量,从而影响策略优化的方向和稳定性。
(2) 价值损失 ( L v a l u e ( ϕ ) L^{value}(\phi) Lvalue(ϕ) )
Critic model 的训练目标是最小化预测价值 (
V
ϕ
(
s
)
V_\phi(s)
Vϕ(s) ) 与实际回报(目标价值)的误差:
L
v
a
l
u
e
(
ϕ
)
=
E
[
(
V
ϕ
(
s
t
)
−
R
t
)
2
]
L^{value}(\phi) = \mathbb{E} \left[ (V_\phi(s_t) - R_t)^2 \right]
Lvalue(ϕ)=E[(Vϕ(st)−Rt)2]
其中:
- ( R t R_t Rt ) 是目标回报,可以是蒙特卡洛回报或 GAE 计算的回报;
- 这部分损失直接优化 Critic model,使其更好地估计状态价值。
(3) 熵项(可选)
熵项 ( − c 2 S [ π θ ] ( s ) -c_2 S[\pi_\theta](s) −c2S[πθ](s) ) 与 Critic 无关,用于防止策略过早收敛到局部最优。
总结 PPO 中 Critic 的体现
- 计算优势:Critic model 输出 ( V ( s ) V(s) V(s) ),用于计算 ( A ( s , a ) A(s, a) A(s,a) ),这是策略梯度更新的核心信号。
- 独立优化:Critic 有自己的损失函数 ( L v a l u e L^{value} Lvalue ),通过监督学习方式与回报对齐。
- 整体作用:Critic 减少了策略梯度的方差,使得 PPO 能够在高维动作空间中更稳定地训练。
GRPO 与 PPO 的区别:为什么不用 Critic Model?
根据文档(第 2.2.1 节),GRPO(Group Relative Policy Optimization)放弃了 PPO 中的 Critic model,而是通过组评分(group scores)来估计基线。具体差异如下:
1. GRPO 的优势计算
- 无 Critic:GRPO 不依赖独立的 Critic model 来预测 (
V
(
s
)
V(s)
V(s) )。相反,它通过对一组输出 (
{
o
1
,
o
2
,
…
,
o
G
}
\{o_1, o_2, \ldots, o_G\}
{o1,o2,…,oG} ) 的奖励 (
{
r
1
,
r
2
,
…
,
r
G
}
\{r_1, r_2, \ldots, r_G\}
{r1,r2,…,rG} ) 进行统计计算来估计优势:
A i = r i − mean ( { r 1 , r 2 , ⋯ , r G } ) std ( { r 1 , r 2 , ⋯ , r G } ) A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \cdots, r_G\})}{\text{std}(\{r_1, r_2, \cdots, r_G\})} Ai=std({r1,r2,⋯,rG})ri−mean({r1,r2,⋯,rG})
这里,基线是组内奖励的均值,方差归一化用于标准化。 - 优点:这种方法避免了训练一个单独的 Critic model,减少了计算开销和参数量,尤其在大规模 RL 中(如 DeepSeek-R1-Zero 的训练)非常重要。
2. GRPO 的损失函数
GRPO 的损失函数与 PPO 类似,但基线由组内统计替代:
J
G
R
P
O
(
θ
)
=
E
[
1
G
∑
i
=
1
G
(
min
(
π
θ
(
o
i
∣
q
)
π
θ
old
(
o
i
∣
q
)
A
i
,
clip
(
π
θ
(
o
i
∣
q
)
π
θ
old
(
o
i
∣
q
)
,
1
−
ε
,
1
+
ε
)
A
i
)
−
β
D
K
L
(
π
θ
∥
π
ref
)
)
]
\mathcal{J}_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta \mathbb{D}_{KL}(\pi_\theta \| \pi_{\text{ref}}) \right) \right]
JGRPO(θ)=E[G1i=1∑G(min(πθold(oi∣q)πθ(oi∣q)Ai,clip(πθold(oi∣q)πθ(oi∣q),1−ε,1+ε)Ai)−βDKL(πθ∥πref))]
- 无 ( L v a l u e L^{value} Lvalue ):由于没有 Critic,GRPO 不需要额外的价值损失项。
- 依赖采样:优势 ( A i A_i Ai ) 完全基于采样的奖励统计,而不是模型预测。
3. 为什么 GRPO 不用 Critic?
- 降低成本:文档明确提到,Critic model 通常与策略模型同等规模,训练它会显著增加计算负担。GRPO 通过组采样直接估计基线,简化了流程。
- 适用场景:在 DeepSeek-R1-Zero 的训练中,奖励(如准确性奖励和格式奖励)是基于规则的,易于计算,不需要复杂的价值估计,适合用统计方法替代 Critic。
PPO 和 GRPO 的对比
特性 | PPO | GRPO |
---|---|---|
Critic Model | 有,独立的价值网络估计 ( V ( s ) V(s) V(s) ) | 无,使用组奖励统计作为基线 |
优势计算 | ( A = r + γ V ( s ′ ) − V ( s ) A = r + \gamma V(s') - V(s) A=r+γV(s′)−V(s) ) 或 GAE | ( A i = r i − mean std A_i = \frac{r_i - \text{mean}}{\text{std}} Ai=stdri−mean ) |
损失函数 | 包含策略损失 + 价值损失 + 熵项 | 仅策略损失 + KL 正则化,无价值损失 |
计算复杂度 | 较高(两网络训练) | 较低(单网络 + 采样统计) |
总结
- PPO 中的 value model 是如何使用的?
- Value model(Critic)用于预测状态价值 ( V ( s ) V(s) V(s) ),作为基线计算优势 ( A ( s , a ) A(s, a) A(s,a) ),并通过 ( L v a l u e L^{value} Lvalue ) 单独优化。
- 是用在算优势那部分吗?
- 是的,Critic 的核心作用是为优势函数提供 ( V ( s ) V(s) V(s) ) 和 ( V ( s ′ ) V(s') V(s′) ),如 ( A = r + γ V ( s ′ ) − V ( s ) A = r + \gamma V(s') - V(s) A=r+γV(s′)−V(s) ) 或 GAE。
- 在 PPO 的 loss 中是怎么体现的?
- 体现在两处:(1) 策略损失中的 ( A ^ t \hat{A}_t A^t ) 依赖 Critic 的输出;(2) 价值损失 ( L v a l u e L^{value} Lvalue ) 直接优化 Critic 的预测。
PPO 中的 KL 散度:在哪里体现?
在 PPO 中,KL 散度并不是直接作为一个独立的损失项加到最终的损失函数中,而是通过策略更新时的约束机制隐式地起作用。PPO 的核心思想是通过限制新旧策略之间的偏差来保证训练的稳定性,而 KL 散度正是用来衡量这种偏差的工具。具体来说,PPO 使用两种方式来实现这一约束:
1. Clipped Surrogate Objective(剪切代理目标)
PPO 的标准形式(也是最常见的形式)使用剪切(clipping)来限制新旧策略之间的变化,而不是直接在损失函数中添加 KL 散度项。其损失函数如下:
L p o l i c y ( θ ) = E [ min ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ε , 1 + ε ) A ^ t ) ] L^{policy}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_t \right) \right] Lpolicy(θ)=E[min(rt(θ)A^t,clip(rt(θ),1−ε,1+ε)A^t)]
其中:
- ( r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(at∣st)πθ(at∣st) ) 是新旧策略的概率比;
- ( ε \varepsilon ε ) 是剪切范围的超参数(通常为 0.1 或 0.2)。
KL 散度的隐式体现:
- ( r t ( θ ) r_t(\theta) rt(θ) ) 表示新策略 ( π θ \pi_\theta πθ ) 和旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ) 在动作概率上的相对变化。当 ( r t ( θ ) r_t(\theta) rt(θ) ) 偏离 1 过多时(即新旧策略差异过大),剪切操作会限制更新幅度。
- 这种剪切本质上是对新旧策略分布差异的一种近似控制,而 KL 散度正是衡量两个概率分布差异的数学工具。虽然公式中没有显式写出 ( D K L ( π θ ∣ ∣ π θ old ) D_{KL}(\pi_\theta || \pi_{\theta_{\text{old}}}) DKL(πθ∣∣πθold) ),但 ( r t ( θ ) r_t(\theta) rt(θ) ) 的剪切间接限制了 KL 散度的大小。
- 在 PPO 的原始论文(Schulman et al., 2017)中,作者提到剪切目标是对信任区域(trust region)的一种简化,而信任区域通常由 KL 散度定义。
2. 自适应 KL 惩罚形式(Adaptive KL Penalty,较少使用)
PPO 的另一种变体(在论文中也有提及,但实践中不常见)通过显式添加 KL 散度惩罚项来替代剪切机制。这种形式的损失函数如下:
L p o l i c y ( θ ) = E [ r t ( θ ) A ^ t − β D K L ( π θ old ∣ ∣ π θ ) ] L^{policy}(\theta) = \mathbb{E} \left[ r_t(\theta) \hat{A}_t - \beta D_{KL}(\pi_{\theta_{\text{old}}} || \pi_\theta) \right] Lpolicy(θ)=E[rt(θ)A^t−βDKL(πθold∣∣πθ)]
- (
β
\beta
β ) 是一个动态调整的惩罚系数,根据实际的 KL 散度大小自适应变化:
- 如果 ( D K L D_{KL} DKL ) 超过目标值(例如 0.01),则增大 ( β \beta β );
- 如果 ( D K L D_{KL} DKL ) 小于目标值,则减小 ( β \beta β )。
这种形式直接使用了 KL 散度,但在实践中,剪切形式(clipped objective)更流行,因为它更简单且不需要手动调整 ( β \beta β )。
为什么标准 PPO 损失中没有显式 KL 项?
- 标准 PPO(剪切版本)选择用剪切机制替代显式 KL 散度惩罚,因为:
- 计算简便:剪切只需计算概率比 ( r t ( θ ) r_t(\theta) rt(θ) ),而 KL 散度需要对整个分布求和或积分,计算成本更高。
- 稳定性:剪切提供了一个硬约束,效果直观且易于实现,避免了 ( β \beta β ) 自适应调节的不确定性。
- 因此,尽管 KL 散度的概念在 PPO 的设计中至关重要(限制策略更新步长),但它通过剪切机制被“隐藏”在了实现中,而不是显式出现在损失公式里。
GRPO 中的 KL 散度:显式体现
相比之下,GRPO(Group Relative Policy Optimization)在损失函数中显式地包含了 KL 散度项,如你提供的公式:
J G R P O ( θ ) = E [ 1 G ∑ i = 1 G ( min ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ε , 1 + ε ) A i ) − β D K L ( π θ ∣ ∣ π ref ) ) ] \mathcal{J}_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta D_{KL}(\pi_\theta || \pi_{\text{ref}}) \right) \right] JGRPO(θ)=E[G1i=1∑G(min(πθold(oi∣q)πθ(oi∣q)Ai,clip(πθold(oi∣q)πθ(oi∣q),1−ε,1+ε)Ai)−βDKL(πθ∣∣πref))]
- 作用:这里的 ( D K L ( π θ ∣ ∣ π ref ) D_{KL}(\pi_\theta || \pi_{\text{ref}}) DKL(πθ∣∣πref) ) 是一个正则化项,限制新策略 ( π θ \pi_\theta πθ ) 偏离参考策略 ( π ref \pi_{\text{ref}} πref )(而不是旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ))的程度。
- 与 PPO 的区别:
- GRPO 同时保留了剪切机制(类似 PPO)和显式 KL 项,提供了双重约束。
- ( π ref \pi_{\text{ref}} πref ) 通常是一个固定的参考分布(例如初始策略),而 PPO 只关注 ( π θ old \pi_{\theta_{\text{old}}} πθold )(上一步的策略)。
PPO 中的 KL 散度代码实现
以下是用 PyTorch 实现的 PPO(剪切版本)核心部分的伪代码,展示 KL 散度的隐式约束如何通过剪切实现:
import torch
import torch.nn as nn
# 假设策略网络和价值网络已经定义
class ActorCritic(nn.Module):
def __init__(self):
super().__init__()
self.actor = nn.Sequential(...) # 输出动作分布参数(如均值和方差)
self.critic = nn.Sequential(...) # 输出状态价值 V(s)
# 计算 PPO 损失
def compute_ppo_loss(actor_critic, states, actions, old_log_probs, advantages, returns, epsilon=0.2):
# 前向传播:获取新策略的 log_prob 和价值估计
dist = actor_critic.actor(states) # 假设输出分布(如正态分布)
new_log_probs = dist.log_prob(actions)
values = actor_critic.critic(states)
# 计算概率比 r_t(θ)
ratios = torch.exp(new_log_probs - old_log_probs) # r_t = π_θ(a|s) / π_θ_old(a|s)
# 策略损失(clipped objective)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
# 价值损失
value_loss = ((values - returns) ** 2).mean()
# 总损失(无显式 KL 项)
total_loss = policy_loss + 0.5 * value_loss # 0.5 是超参数 c1
return total_loss
# 训练循环(伪代码)
def train_ppo(actor_critic, optimizer, data):
states, actions, old_log_probs, advantages, returns = data
optimizer.zero_grad()
loss = compute_ppo_loss(actor_critic, states, actions, old_log_probs, advantages, returns)
loss.backward()
optimizer.step()
# 可选:计算实际 KL 散度(仅用于监控,不影响损失)
def compute_kl_divergence(dist_new, dist_old):
return torch.distributions.kl_divergence(dist_new, dist_old).mean()
关键点解释:
- 剪切机制:
ratios
表示 ( r t ( θ ) r_t(\theta) rt(θ) ),通过torch.clamp
限制在 ( [ 1 − ε , 1 + ε ] [1-\varepsilon, 1+\varepsilon] [1−ε,1+ε] ) 范围内。- 这隐式地控制了新旧策略的 KL 散度,避免过大更新。
- 无显式 KL 项:
- 损失函数中没有直接计算 (
D
K
L
D_{KL}
DKL ),但可以通过
compute_kl_divergence
单独监控(例如打印日志)。
- 损失函数中没有直接计算 (
D
K
L
D_{KL}
DKL ),但可以通过
- 实现细节:
old_log_probs
是从旧策略采样的动作的对数概率,需要在新策略更新前保存。dist
是新策略输出的分布(如torch.distributions.Normal
)。
如果使用自适应 KL 惩罚形式:
以下是 PPO 的 KL 惩罚版本的伪代码:
def compute_ppo_kl_loss(actor_critic, states, actions, old_dist, advantages, beta=0.01):
dist = actor_critic.actor(states)
new_log_probs = dist.log_prob(actions)
ratios = torch.exp(new_log_probs - old_dist.log_prob(actions))
# 策略损失
policy_loss = -(ratios * advantages).mean()
# KL 散度惩罚
kl_div = torch.distributions.kl_divergence(dist, old_dist).mean()
total_loss = policy_loss + beta * kl_div
return total_loss
- 显式 KL:这里直接计算 ( D K L D_{KL} DKL ) 并加到损失中,( β \beta β ) 需要根据 KL 大小动态调整。
总结
- PPO 中的 KL 散度:
- 在标准剪切版本中,KL 散度通过 ( r t ( θ ) r_t(\theta) rt(θ) ) 的剪切隐式体现,不出现在损失公式中。
- 在 KL 惩罚版本中,KL 散度显式加到损失函数中,但实践中较少使用。
- GRPO 中的 KL 散度:
- 显式出现在损失公式中,作为正则化项约束新策略偏离参考策略。
- 代码实现:
- 剪切版本简单高效,只需计算概率比并剪切;
- KL 惩罚版本需要额外计算分布间的 KL 散度,复杂度稍高。
PPO 中的 Clip 为什么间接代替了 KL 散度?
在 PPO 中,剪切机制(clipped surrogate objective)是为了实现信任区域(trust region)优化的简单近似,而信任区域通常是通过 KL 散度来定义的:
- 信任区域的意义:强化学习中,策略更新需要在新旧策略之间保持一定的一致性,避免过大的步长导致性能崩溃。KL 散度 ( D K L ( π θ old ∣ ∣ π θ ) D_{KL}(\pi_{\theta_{\text{old}}} || \pi_\theta) DKL(πθold∣∣πθ) ) 是衡量新旧策略分布差异的自然选择。
- PPO 的剪切设计:
- PPO 使用概率比 ( r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(at∣st)πθ(at∣st) ) 的剪切(限制在 ( [ 1 − ε , 1 + ε ] [1-\varepsilon, 1+\varepsilon] [1−ε,1+ε] ))来近似控制新旧策略的偏差。
- 当 ( r t ( θ ) r_t(\theta) rt(θ)) 偏离 1 过多时(即新策略偏离旧策略过远),剪切会限制梯度更新幅度。这种限制本质上是对 KL 散度的间接约束,因为 ( r t ( θ ) r_t(\theta) rt(θ) ) 的变化与 KL 散度密切相关(例如,当分布差异小时,( r t ( θ ) ≈ 1 + D K L r_t(\theta) \approx 1 + D_{KL} rt(θ)≈1+DKL ) 的线性近似)。
- 为什么不用显式 KL?:
- 计算 KL 散度需要对整个分布求和或积分,成本较高,而剪切只需计算单点概率比,简单高效。
- PPO 的目标是提供一个易于实现的算法,剪切机制在实践中效果很好,且避免了手动调整 KL 惩罚系数(如 ( β \beta β ))的不确定性。
因此,PPO 的剪切机制是对信任区域的“代理”(surrogate),间接实现了 KL 散度的约束功能。
GRPO 为什么同时有 Clip 和 KL 散度?
GRPO 的损失函数如下:
J G R P O ( θ ) = E [ 1 G ∑ i = 1 G ( min ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ε , 1 + ε ) A i ) − β D K L ( π θ ∣ ∣ π ref ) ) ] \mathcal{J}_{GRPO}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta D_{KL}(\pi_\theta || \pi_{\text{ref}}) \right) \right] JGRPO(θ)=E[G1i=1∑G(min(πθold(oi∣q)πθ(oi∣q)Ai,clip(πθold(oi∣q)πθ(oi∣q),1−ε,1+ε)Ai)−βDKL(πθ∣∣πref))]
GRPO 同时保留剪切和显式 KL 散度的原因可以从以下几个方面理解:
1. 双重约束的不同目标
- Clip 的作用:
- 类似于 PPO,剪切限制了新策略 ( π θ \pi_\theta πθ ) 相对于旧策略 ( π θ old \pi_{\theta_{\text{old}}} πθold ) 的单步更新幅度,确保每次迭代不会偏离过远。这是局部信任区域的保障,关注的是“当前更新”的稳定性。
- KL 散度的作用:
- GRPO 中的 ( D K L ( π θ ∣ ∣ π ref ) D_{KL}(\pi_\theta || \pi_{\text{ref}}) DKL(πθ∣∣πref) ) 是相对于一个固定的参考策略 ( π ref \pi_{\text{ref}} πref )(通常是初始策略或某个基准策略)的全局约束。它防止新策略在整个训练过程中偏离初始分布太远,起到长期正则化的作用。
- 区别:
- Clip 关注的是短期的、迭代间的稳定性(( π θ \pi_\theta πθ ) vs. ( π θ old \pi_{\theta_{\text{old}}} πθold ))。
- KL 散度关注的是长期的、相对于初始状态的稳定性(( π θ \pi_\theta πθ ) vs. ( π ref \pi_{\text{ref}} πref ))。
2. GRPO 的设计背景
- 去掉 Critic 的影响:GRPO 放弃了 PPO 中的 Critic model,使用组内奖励统计来计算优势 ( A i A_i Ai )。这种设计减少了计算成本,但可能导致策略更新方向不够稳定(因为没有价值函数提供基线)。因此,GRPO 通过双重约束(clip + KL)来增强稳定性。
- 语言模型的复杂性:在 DeepSeek-R1-Zero 的训练场景中,输出空间(如自然语言 token 序列)比传统 RL(如游戏)的动作空间更复杂。显式 KL 散度可以防止模型生成过于离谱的输出(例如语言混合或不可读内容),这在文档中提到的 DeepSeek-R1-Zero 的局限性中有所体现。
3. 实际效果
- 互补性:剪切和 KL 散度的组合提供了更强的控制:
- 剪切确保每次更新不会过于激进;
- KL 散度防止模型在长时间训练后“漂移”到不可控的状态。
- 实验驱动:GRPO 的设计可能基于实验观察,发现单独使用剪切不足以应对大规模语言模型训练中的挑战(如 reward hacking 或分布崩溃),因此引入显式 KL 项作为额外保障。
PPO 可以同时拥有 Clip 和 KL 散度吗?
理论上,PPO 完全可以同时拥有剪切和显式 KL 散度,而且这种组合在某些变体中已经被提出或实验过。下面是分析和实现思路:
1. 可行性分析
- 理论依据:
- PPO 的剪切机制是对信任区域的近似,而显式 KL 散度是信任区域的直接表达。结合起来可以提供更细粒度的控制。
- 例如,可以用剪切限制单步更新,用 KL 散度正则化全局分布,类似于 GRPO 的思路。
- 潜在优势:
- 在高维或复杂任务中(如语言生成),仅靠剪切可能不足以防止策略分布的退化,添加 KL 散度可以增强鲁棒性。
- 潜在问题:
- 超参数调节:需要同时调整剪切范围 ( ε \varepsilon ε ) 和 KL 惩罚系数 ( β \beta β ),增加了调参难度。
- 计算开销:显式计算 KL 散度会增加计算负担,尤其是对于连续分布或高维离散分布。
2. PPO + Clip + KL 的损失函数
如果在 PPO 中加入显式 KL 散度,损失函数可能如下:
L P P O + K L ( θ ) = E [ min ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ε , 1 + ε ) A ^ t ) − β D K L ( π θ ∣ ∣ π ref ) ] + c 1 L v a l u e ( ϕ ) L^{PPO+KL}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon) \hat{A}_t \right) - \beta D_{KL}(\pi_\theta || \pi_{\text{ref}}) \right] + c_1 L^{value}(\phi) LPPO+KL(θ)=E[min(rt(θ)A^t,clip(rt(θ),1−ε,1+ε)A^t)−βDKL(πθ∣∣πref)]+c1Lvalue(ϕ)
- 与 GRPO 的区别:
- PPO 保留了 Critic model 和价值损失 ( L v a l u e L^{value} Lvalue ),而 GRPO 没有。
- ( π ref \pi_{\text{ref}} πref ) 可以选择为 ( π θ old \pi_{\theta_{\text{old}}} πθold )(动态参考旧策略)或固定的初始策略。
3. 代码实现示例
以下是用 PyTorch 实现的 PPO + Clip + KL 的伪代码:
import torch
import torch.nn as nn
class ActorCritic(nn.Module):
def __init__(self):
super().__init__()
self.actor = nn.Sequential(...) # 输出动作分布
self.critic = nn.Sequential(...) # 输出状态价值
def compute_ppo_with_kl_loss(actor_critic, states, actions, old_log_probs, advantages, returns, ref_dist, epsilon=0.2, beta=0.01):
# 新策略分布和价值
dist = actor_critic.actor(states)
new_log_probs = dist.log_prob(actions)
values = actor_critic.critic(states)
# 概率比和剪切损失
ratios = torch.exp(new_log_probs - old_log_probs)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
# 价值损失
value_loss = ((values - returns) ** 2).mean()
# 显式 KL 散度(相对于参考分布)
kl_div = torch.distributions.kl_divergence(dist, ref_dist).mean()
# 总损失
total_loss = policy_loss + 0.5 * value_loss + beta * kl_div
return total_loss
# 训练循环
def train_ppo_with_kl(actor_critic, optimizer, data, ref_dist):
states, actions, old_log_probs, advantages, returns = data
optimizer.zero_grad()
loss = compute_ppo_with_kl_loss(actor_critic, states, actions, old_log_probs, advantages, returns, ref_dist)
loss.backward()
optimizer.step()
- 关键点:
ref_dist
是参考分布,可以是初始策略的输出分布(需预先保存)。kl_div
使用 PyTorch 的kl_divergence
函数计算,适用于常见分布(如正态分布或离散分布)。
4. PPO 不常用 Clip + KL 的原因
- 冗余性:剪切已经很好地实现了信任区域约束,添加 KL 散度可能收益不大,反而增加复杂度。
- 简单性优先:PPO 的设计目标是简单易用,剪切机制在大多数任务中足够稳定,显式 KL 散度显得多余。
- 任务依赖:在语言模型等复杂任务中,显式 KL 可能更有价值,但在传统 RL 任务(如 Atari 游戏)中,剪切已足够。
总结
- 为什么 GRPO 既有 Clip 又有 KL?
- GRPO 去掉了 Critic,依赖双重约束(剪切控制单步稳定性,KL 控制全局正则化)来应对复杂任务(如语言生成)的挑战。
- PPO 可以同时有两者吗?
- 可以,理论上可行且实现简单,但实践中很少这样做,因为剪切已足够,且增加 KL 散度会提高计算成本和调参难度。
- 区别的根源:
- GRPO 的设计适应了大规模语言模型训练的需求(无 Critic、高维输出),而 PPO 更通用,倾向于简单高效。
后记
2025年3月3日13点42分于上海,在grok3大模型辅助下完成。