全面解读PPO算法:结合DeepSpeed Chat实现分析, Critic Loss的设计

全面解读PPO算法:结合DeepSpeed Chat实现分析


1. 什么是PPO?

Proximal Policy Optimization (PPO) 是一种基于策略梯度的强化学习方法,属于 Actor-Critic 框架的改进算法。它的目标是稳定地优化策略,避免策略更新过于激进,同时保持训练效率。

在 PPO 中,主要包含两个核心模块:

  • Actor:负责学习策略 ( π θ ( a ∣ s ) \pi_\theta(a|s) πθ(as)),即选择某一动作的概率分布。
  • Critic:负责估计状态值 ( V ( s ) V(s) V(s)),为策略优化提供参考。

PPO 的设计核心是裁剪(Clipping)策略,确保策略更新幅度受控,从而提高训练的稳定性。
关于PPO的训练流程,可以参考笔者的另一篇博客:RLHF (PPO) 流程详解: Proximal Policy Optimization


2. PPO 的两个核心损失

PPO 的优化目标包括 Actor Loss(策略损失)和 Critic Loss(值函数损失)。我们将结合 DeepSpeed Chat 的实现,详细讲解这两部分的设计。


2.1 Actor Loss:策略裁剪目标

Actor 的目标是优化策略,使得它选择动作的概率与优势函数 ( A t A_t At) 成正比。优势函数表示当前动作的优劣程度,定义为:

A t = Q ( s t , a t ) − V ( s t ) A_t = Q(s_t, a_t) - V(s_t) At=Q(st,at)V(st)

PPO实际上采用的是GAE版本的优势函数,请参考笔者的另一篇博客:深入解析强化学习中的 Generalized Advantage Estimation (GAE)

为了防止策略更新过快,PPO 对策略更新的目标函数加入裁剪限制:

r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst)

损失函数为:

L actor ( θ ) = E t [ min ⁡ ( r t ( θ ) ⋅ A t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ A t ) ] \mathcal{L}_{\text{actor}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t \right) \right] Lactor(θ)=Et[min(rt(θ)At,clip(rt(θ),1ϵ,1+ϵ)At)]

其中:

  • ( r t ( θ ) r_t(\theta) rt(θ)):新策略与旧策略概率的比值。
  • ( ϵ \epsilon ϵ):裁剪范围,通常取 0.1 ~ 0.3。
  • ( clip ( ) \text{clip}() clip()):将比值限制在 ( [ 1 − ϵ , 1 + ϵ ] [1-\epsilon, 1+\epsilon] [1ϵ,1+ϵ]) 范围内。

PPO 的策略目标通过裁剪限制 ( r t ( θ ) r_t(\theta) rt(θ)),避免策略更新幅度过大。

2.2 Critic Loss:值函数裁剪目标

Critic 的目标是学习一个值函数 ( V ( s ) V(s) V(s)),使其接近实际的回报 ( R t R_t Rt)。通常使用均方误差 (MSE) 作为损失函数:

L critic ( ϕ ) = E t [ ( V ϕ ( s t ) − R t ) 2 ] \mathcal{L}_{\text{critic}}(\phi) = \mathbb{E}_t \left[ \left( V_\phi(s_t) - R_t \right)^2 \right] Lcritic(ϕ)=Et[(Vϕ(st)Rt)2]

在 DeepSpeed Chat 的实现中,Critic Loss 引入了裁剪机制,限制值函数的更新幅度:

V ϕ ( s t ) clipped = clip ( V ϕ ( s t ) , V ϕ ( s t ) old − ϵ , V ϕ ( s t ) old + ϵ ) V_\phi(s_t)^{\text{clipped}} = \text{clip}(V_\phi(s_t), V_\phi(s_t)^{\text{old}} - \epsilon, V_\phi(s_t)^{\text{old}} + \epsilon) Vϕ(st)clipped=clip(Vϕ(st),Vϕ(st)oldϵ,Vϕ(st)old+ϵ)

损失函数为:

L critic ( ϕ ) = 1 2 ⋅ E t [ max ⁡ ( ( V ϕ ( s t ) − R t ) 2 , ( V ϕ ( s t ) clipped − R t ) 2 ) ] \mathcal{L}_{\text{critic}}(\phi) = \frac{1}{2} \cdot \mathbb{E}_t \left[ \max \left( \left( V_\phi(s_t) - R_t \right)^2, \left( V_\phi(s_t)^{\text{clipped}} - R_t \right)^2 \right) \right] Lcritic(ϕ)=21Et[max((Vϕ(st)Rt)2,(Vϕ(st)clippedRt)2)]

这种设计能够限制值函数的剧烈变化,从而提高训练稳定性。


3. PPO 的实现:结合 DeepSpeed Chat

DeepSpeed Chat 中,PPO 的实现集中在以下几个核心部分。

3.1 Actor 和 Critic 损失的计算

train_rlhf 方法中,分别计算 Actor Loss 和 Critic Loss:
注:这段代码只是模拟DeepSpeed Chat,具体实现请看源代码。但是思路是一致的,这里为的是方便理解,进行的简化。

def train_rlhf(self, exp_data):
    # 提取输入数据
    logprobs, old_logprobs = exp_data["logprobs"], exp_data["old_logprobs"]
    values, old_values = exp_data["values"], exp_data["old_values"]
    rewards, returns = exp_data["rewards"], exp_data["returns"]
    advantages = returns - values #优势仅仅是模拟,具体实现请看下文中的源代码解析

    # 计算 Actor 损失
    # 和DeepSpeed Chat稍有区别,下文的源代码解析中有讲到
    log_ratio = (logprobs - old_logprobs) * exp_data["mask"]
    ratio = torch.exp(log_ratio)
    actor_loss1 = advantages * ratio
    actor_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
    actor_loss = -torch.sum(torch.min(actor_loss1, actor_loss2) * exp_data["mask"]) / exp_data["mask"].sum()

    # 计算 Critic 损失
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    vf_loss1 = (values - returns) ** 2
    vf_loss2 = (values_clipped - returns) ** 2
    critic_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * exp_data["mask"]
    ) / exp_data["mask"].sum()

    return actor_loss, critic_loss
  • Actor 损失:根据裁剪策略计算 ( L actor \mathcal{L}_{\text{actor}} Lactor)。
  • Critic 损失:通过裁剪值函数计算 ( L critic \mathcal{L}_{\text{critic}} Lcritic)。

实际上,原仓库对它们进行了封装,分别计算两个loss。
下面是 https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py封装的函数,具体解析请参考笔者的另一篇博客: 基于DeepSpeed Chat详解 PPO 算法中的actor_loss_fn及其核心参数

 def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        ## policy gradient loss
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss

 def critic_loss_fn(self, values, old_values, returns, mask):
     ## value loss
     values_clipped = torch.clamp(
         values,
         old_values - self.cliprange_value,
         old_values + self.cliprange_value,
     )
     if self.compute_fp32_loss:
         values = values.float()
         values_clipped = values_clipped.float()
     vf_loss1 = (values - returns)**2
     vf_loss2 = (values_clipped - returns)**2
     vf_loss = 0.5 * torch.sum(
         torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
     return vf_loss

关于上文提到的优势计算,DeepSpeed Chat用到的是GAE版本的:
下面代码的解析请参考笔者的另一篇博客:深入理解 Generalized Advantage Estimation (GAE) 及其代码实现:以DeepSpeed-Chat中PPO算法使用为例

def get_advantages_and_returns(self, values, rewards, start):
        # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
        lastgaelam = 0
        advantages_reversed = []
        length = rewards.size()[-1]
        for t in reversed(range(start, length)):
            nextvalues = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.gamma * self.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values[:, start:]
        return advantages.detach(), returns

3.2 主训练循环

main.py 中,PPO 的训练循环如下:

for ppo_ep in range(args.ppo_epochs):
    for i, (exp_data, unsup_data) in enumerate(zip(exp_dataset, unsup_dataset)):
        # 训练 Actor 和 Critic
        actor_loss, critic_loss = trainer.train_rlhf(exp_data)

        # 记录损失
        actor_loss_sum += actor_loss.item()
        critic_loss_sum += critic_loss.item()
        average_reward += exp_data["rewards"].mean()
  • exp_data:经验数据,包含策略概率(logprobs)、状态值(values)、回报(returns)等。
  • train_rlhf():调用训练函数,返回 Actor 和 Critic 的损失。
  • actor_loss_sumcritic_loss_sum:累积损失用于记录训练进展。

3.3 熵的缺失

值得注意的是,DeepSpeed Chat 的实现没有显式地加入 熵正则项。熵正则项通常用于增加策略的随机性,从而提高探索性。其常见形式为:

L entropy = E t [ − ∑ a π ( a ∣ s t ) log ⁡ π ( a ∣ s t ) ] \mathcal{L}_{\text{entropy}} = \mathbb{E}_t \left[ -\sum_a \pi(a|s_t) \log \pi(a|s_t) \right] Lentropy=Et[aπ(ast)logπ(ast)]

在没有熵正则项的情况下,模型可能更快收敛到局部最优策略,而缺乏足够的探索。


4. PPO 的总损失函数

PPO 的总损失为 Actor Loss 和 Critic Loss 的加权和,通常还包括熵正则项:

L total = L actor + c 1 ⋅ L critic − c 2 ⋅ L entropy \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{actor}} + c_1 \cdot \mathcal{L}_{\text{critic}} - c_2 \cdot \mathcal{L}_{\text{entropy}} Ltotal=Lactor+c1Lcriticc2Lentropy

  • ( c 1 , c 2 c_1, c_2 c1,c2):权重超参数,用于平衡各项损失。

在 DeepSpeed Chat 的实现中,熵正则项未被显式加入,因此总损失实际上为:

L total = L actor + c 1 ⋅ L critic \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{actor}} + c_1 \cdot \mathcal{L}_{\text{critic}} Ltotal=Lactor+c1Lcritic


5. 总结与改进建议

5.1 DeepSpeed Chat 的 PPO 实现特点
  1. Actor 和 Critic 的裁剪机制:通过裁剪策略和值函数更新,保证了训练的稳定性。
  2. 简化实现:在损失函数中省略了熵正则项,从而简化了实现。
5.2 改进建议

为了增强模型的探索性,可以加入熵正则项,并将其权重 (c_2) 调整为适当的值。示例代码如下:

# 计算熵正则项
entropy = -torch.sum(exp_data["logprobs"] * torch.exp(exp_data["logprobs"]) * exp_data["mask"]) / exp_data["mask"].sum()

# 总损失函数
total_loss = actor_loss + self.vf_coef * critic_loss - self.entropy_coef * entropy

加入熵正则项后,模型可以在探索和利用之间实现更好的平衡。

5.3 总结

PPO 是一种高效且稳定的强化学习算法,其 Actor 和 Critic 的裁剪机制是其核心设计。DeepSpeed Chat 的实现体现了 PPO 的简化设计,同时也为研究者提供了扩展的空间,例如加入熵正则项、调整损失权重等。

为什么 Critic Loss 使用 max 函数?

在 Proximal Policy Optimization (PPO) 算法中,Critic Loss 使用了裁剪机制来限制值函数的更新幅度,以防止训练过程中的不稳定。损失函数形式如下:

L critic ( ϕ ) = 1 2 ⋅ E t [ max ⁡ ( ( V ϕ ( s t ) − R t ) 2 , ( V ϕ ( s t ) clipped − R t ) 2 ) ] \mathcal{L}_{\text{critic}}(\phi) = \frac{1}{2} \cdot \mathbb{E}_t \left[ \max \left( \left( V_\phi(s_t) - R_t \right)^2, \left( V_\phi(s_t)^{\text{clipped}} - R_t \right)^2 \right) \right] Lcritic(ϕ)=21Et[max((Vϕ(st)Rt)2,(Vϕ(st)clippedRt)2)]

其中:

  • ( V ϕ ( s t ) V_\phi(s_t) Vϕ(st) ):当前值网络对状态 ( s t s_t st ) 的估计。
  • ( R t R_t Rt ):目标回报(实际回报)。
  • ( V ϕ ( s t ) clipped V_\phi(s_t)^{\text{clipped}} Vϕ(st)clipped ):裁剪后的值函数:
    V ϕ ( s t ) clipped = clip ( V ϕ ( s t ) , V ϕ ( s t ) old − ϵ , V ϕ ( s t ) old + ϵ ) V_\phi(s_t)^{\text{clipped}} = \text{clip}\left( V_\phi(s_t), V_\phi(s_t)^{\text{old}} - \epsilon, V_\phi(s_t)^{\text{old}} + \epsilon \right) Vϕ(st)clipped=clip(Vϕ(st),Vϕ(st)oldϵ,Vϕ(st)old+ϵ)
    其中 ( V ϕ ( s t ) old V_\phi(s_t)^{\text{old}} Vϕ(st)old ) 是前一次的估计值,( ϵ \epsilon ϵ) 控制裁剪的范围。

为什么要用 max 函数?

max 函数的引入是为了在值函数更新时,限制过度估计带来的不稳定性,同时保证学习的效果。

  • 第一项 ( ( V ϕ ( s t ) − R t ) 2 (V_\phi(s_t) - R_t)^2 (Vϕ(st)Rt)2):表示当前值函数 ( V ϕ ( s t ) V_\phi(s_t) Vϕ(st) ) 和目标回报 ( R t R_t Rt ) 的误差。
  • 第二项 ( ( V ϕ ( s t ) clipped − R t ) 2 (V_\phi(s_t)^{\text{clipped}} - R_t)^2 (Vϕ(st)clippedRt)2):表示裁剪后的值函数 ( V ϕ ( s t ) clipped V_\phi(s_t)^{\text{clipped}} Vϕ(st)clipped ) 和目标回报 ( R t R_t Rt ) 的误差。

通过 max 操作,PPO 选择误差较大的那一项来计算损失,确保以下两点:

  1. 稳定训练过程:当 ( V ϕ ( s t ) V_\phi(s_t) Vϕ(st) ) 变化过大时,裁剪机制 ( V ϕ ( s t ) clipped V_\phi(s_t)^{\text{clipped}} Vϕ(st)clipped ) 会限制更新幅度。
  2. 避免过度惩罚:即使值函数的更新被裁剪,也不会因为裁剪而导致损失函数过于偏离目标。

数值模拟解析

假设:

  • ( V ϕ old = 10 V_\phi^{\text{old}} = 10 Vϕold=10 ):前一次值函数的估计值。
  • ( ϵ = 2 \epsilon = 2 ϵ=2 ):裁剪范围。
  • ( R t = 12 R_t = 12 Rt=12 ):目标回报。

我们分别计算以下三种情况:

  1. 当前值函数估计 ( V ϕ = 14 V_\phi = 14 Vϕ=14 )(过高估计)
  2. 当前值函数估计 ( V ϕ = 8 V_\phi = 8 Vϕ=8 )(过低估计)
  3. 当前值函数估计 ( V ϕ = 11 V_\phi = 11 Vϕ=11 )(合理范围内)

我们来计算每种情况下的 Critic Loss:


情况1:过高估计 ( V ϕ = 14 V_\phi = 14 Vϕ=14 )
  1. 裁剪前误差
    ( V ϕ − R t ) 2 = ( 14 − 12 ) 2 = 4 (V_\phi - R_t)^2 = (14 - 12)^2 = 4 (VϕRt)2=(1412)2=4

  2. 裁剪后的 ( V ϕ clipped V_\phi^{\text{clipped}} Vϕclipped ):
    V ϕ clipped = clip ( 14 , 10 − 2 , 10 + 2 ) = 12 V_\phi^{\text{clipped}} = \text{clip}(14, 10 - 2, 10 + 2) = 12 Vϕclipped=clip(14,102,10+2)=12
    裁剪后误差
    ( V ϕ clipped − R t ) 2 = ( 12 − 12 ) 2 = 0 (V_\phi^{\text{clipped}} - R_t)^2 = (12 - 12)^2 = 0 (VϕclippedRt)2=(1212)2=0

  3. Critic Loss
    L critic = 1 2 ⋅ max ⁡ ( 4 , 0 ) = 1 2 ⋅ 4 = 2 \mathcal{L}_{\text{critic}} = \frac{1}{2} \cdot \max(4, 0) = \frac{1}{2} \cdot 4 = 2 Lcritic=21max(4,0)=214=2


情况2:过低估计 ( V ϕ = 8 V_\phi = 8 Vϕ=8 )
  1. 裁剪前误差
    ( V ϕ − R t ) 2 = ( 8 − 12 ) 2 = 16 (V_\phi - R_t)^2 = (8 - 12)^2 = 16 (VϕRt)2=(812)2=16

  2. 裁剪后的 ( V ϕ clipped V_\phi^{\text{clipped}} Vϕclipped ):
    V ϕ clipped = clip ( 8 , 10 − 2 , 10 + 2 ) = 10 V_\phi^{\text{clipped}} = \text{clip}(8, 10 - 2, 10 + 2) = 10 Vϕclipped=clip(8,102,10+2)=10
    裁剪后误差
    ( V ϕ clipped − R t ) 2 = ( 10 − 12 ) 2 = 4 (V_\phi^{\text{clipped}} - R_t)^2 = (10 - 12)^2 = 4 (VϕclippedRt)2=(1012)2=4

  3. Critic Loss
    L critic = 1 2 ⋅ max ⁡ ( 16 , 4 ) = 1 2 ⋅ 16 = 8 \mathcal{L}_{\text{critic}} = \frac{1}{2} \cdot \max(16, 4) = \frac{1}{2} \cdot 16 = 8 Lcritic=21max(16,4)=2116=8


情况3:合理范围内 ( V ϕ = 11 V_\phi = 11 Vϕ=11 )
  1. 裁剪前误差
    ( V ϕ − R t ) 2 = ( 11 − 12 ) 2 = 1 (V_\phi - R_t)^2 = (11 - 12)^2 = 1 (VϕRt)2=(1112)2=1

  2. 裁剪后的 ( V ϕ clipped V_\phi^{\text{clipped}} Vϕclipped ):
    V ϕ clipped = clip ( 11 , 10 − 2 , 10 + 2 ) = 11 V_\phi^{\text{clipped}} = \text{clip}(11, 10 - 2, 10 + 2) = 11 Vϕclipped=clip(11,102,10+2)=11
    裁剪后误差
    ( V ϕ clipped − R t ) 2 = ( 11 − 12 ) 2 = 1 (V_\phi^{\text{clipped}} - R_t)^2 = (11 - 12)^2 = 1 (VϕclippedRt)2=(1112)2=1

  3. Critic Loss
    L critic = 1 2 ⋅ max ⁡ ( 1 , 1 ) = 1 2 ⋅ 1 = 0.5 \mathcal{L}_{\text{critic}} = \frac{1}{2} \cdot \max(1, 1) = \frac{1}{2} \cdot 1 = 0.5 Lcritic=21max(1,1)=211=0.5


总结

通过以上数值模拟,我们可以看到:

  1. 过高估计时:裁剪机制限制了 ( V ϕ V_\phi Vϕ ) 的更新幅度,Critic Loss 较小。
  2. 过低估计时:裁剪机制限制 ( V ϕ V_\phi Vϕ ) 过度下降,但仍允许一定程度的更新,Critic Loss 较大。
  3. 合理范围内时:Critic Loss 最小,表示值函数估计已经接近目标回报。

Critic Loss 使用 max 的意义

  • 避免值函数更新过大(通过裁剪限制)。
  • 同时保证训练仍然能够向正确的方向优化(选择较大误差)。
  • 提高训练的稳定性,减少梯度爆炸或值函数震荡的风险。

这种设计是 PPO 算法稳定性的重要来源。

后记

2024年12月14日15点14分于上海,在GPT4o大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值