深入理解 Generalized Advantage Estimation (GAE) 及其代码实现:以DeepSpeed-Chat中PPO算法使用为例

DeepSpeed-Chat中PPO算法中采用GAE方法,这里学习记录一下。

https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py#L44

def get_advantages_and_returns(self, values, rewards, start):
        # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
        lastgaelam = 0
        advantages_reversed = []
        length = rewards.size()[-1]
        for t in reversed(range(start, length)):
            nextvalues = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.gamma * self.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values[:, start:]
        return advantages.detach(), returns

下面我们将详细分析代码 get_advantages_and_returns 的实现,并结合理论解释和实际数值例子来讲解 Generalized Advantage Estimation (GAE) 的应用和计算过程。最终整理成一篇博客形式。


深入理解 Generalized Advantage Estimation (GAE) 及其代码实现

背景理论

在强化学习中,策略梯度方法通常需要估计每个动作的优势函数 ( A t A_t At )。优势函数衡量了一个动作比平均水平好的程度。公式如下:

A t = R t + γ V t + 1 − V t A_t = R_t + \gamma V_{t+1} - V_t At=Rt+γVt+1Vt

其中:

  • ( R t R_t Rt ):即时奖励
  • ( V t V_t Vt ):当前状态 ( S t S_t St ) 的价值估计
  • ( γ \gamma γ ):折扣因子,用于控制未来奖励的重要性

GAE 的核心思想是将当前时间步 ( t t t ) 的优势函数 ( A t A_t At ) 拆分成即时误差 ( δ t \delta_t δt ) 和未来累计误差的加权和。通过引入参数 ( λ \lambda λ ),GAE 可以控制优势估计对长期误差的关注程度。其递推公式为:

δ t = R t + γ V t + 1 − V t \delta_t = R_t + \gamma V_{t+1} - V_t δt=Rt+γVt+1Vt

A t = δ t + γ λ A t + 1 A_t = \delta_t + \gamma \lambda A_{t+1} At=δt+γλAt+1

参数 ( λ \lambda λ ) 控制了对未来优势的衰减。当 ( λ = 0 \lambda = 0 λ=0 ) 时,只考虑当前的即时误差;当 ( λ = 1 \lambda = 1 λ=1 ) 时,将所有未来的误差都考虑在内。

GAE 的优点是通过引入 ( λ \lambda λ ) 来平衡偏差与方差,从而提高策略梯度方法的稳定性和效率。


代码解析

代码部分
def get_advantages_and_returns(self, values, rewards, start):
    # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values[:, start:]
    return advantages.detach(), returns
核心逻辑分解
  1. 输入参数

    • values:每个时间步的状态价值估计 ( V t V_t Vt )(形状为 [batch, time])。
    • rewards:每个时间步的即时奖励 ( R t R_t Rt )。
    • start:从哪个时间步开始计算 GAE。
  2. 初始化

    • lastgaelam:存储当前时间步的 GAE 值。
    • advantages_reversed:存储逆序的 GAE 结果。
  3. 循环计算

    • 从最后一个时间步向前递推:
      • 计算当前时间步的即时误差 ( δ t \delta_t δt ):
        delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
        
      • 根据 GAE 公式递推更新 ( A t A_t At ):
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        
      • 将 ( A t A_t At ) 存储到逆序列表中。
  4. 优势与回报计算

    • 将逆序的 GAE 转为正序:
      advantages = torch.stack(advantages_reversed[::-1], dim=1)
      
    • 计算回报 ( R t R_t Rt ):
      returns = advantages + values[:, start:]
      
  5. 输出结果

    • 返回 GAE 的优势函数 ( A t A_t At ) 和回报 ( R t R_t Rt )。

实际数值例子

假设我们有以下输入数据:

values = torch.tensor([[0.5, 0.6, 0.7, 0.8]], dtype=torch.float32)  # 状态值
rewards = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32)  # 即时奖励
gamma = 0.99
lam = 0.95
计算过程
  1. 时间步 ( t = 3 t=3 t=3 )

    • ( δ 3 = R 3 + γ ⋅ 0 − V 3 = 1 + 0 − 0.8 = 0.2 \delta_3 = R_3 + \gamma \cdot 0 - V_3 = 1 + 0 - 0.8 = 0.2 δ3=R3+γ0V3=1+00.8=0.2 )
    • ( A 3 = δ 3 = 0.2 A_3 = \delta_3 = 0.2 A3=δ3=0.2 )
  2. 时间步 ( t = 2 t=2 t=2 )

    • ( δ 2 = R 2 + γ ⋅ V 3 − V 2 = 1 + 0.99 ⋅ 0.8 − 0.7 = 1.09 − 0.7 = 0.39 \delta_2 = R_2 + \gamma \cdot V_3 - V_2 = 1 + 0.99 \cdot 0.8 - 0.7 = 1.09 - 0.7 = 0.39 δ2=R2+γV3V2=1+0.990.80.7=1.090.7=0.39 )
    • ( A 2 = δ 2 + γ ⋅ λ ⋅ A 3 = 0.39 + 0.99 ⋅ 0.95 ⋅ 0.2 = 0.39 + 0.1881 = 0.5781 A_2 = \delta_2 + \gamma \cdot \lambda \cdot A_3 = 0.39 + 0.99 \cdot 0.95 \cdot 0.2 = 0.39 + 0.1881 = 0.5781 A2=δ2+γλA3=0.39+0.990.950.2=0.39+0.1881=0.5781 )
  3. 时间步 ( t = 1 t=1 t=1 )

    • ( δ 1 = R 1 + γ ⋅ V 2 − V 1 = 1 + 0.99 ⋅ 0.7 − 0.6 = 1.09 − 0.6 = 0.49 \delta_1 = R_1 + \gamma \cdot V_2 - V_1 = 1 + 0.99 \cdot 0.7 - 0.6 = 1.09 - 0.6 = 0.49 δ1=R1+γV2V1=1+0.990.70.6=1.090.6=0.49 )
    • ( A 1 = δ 1 + γ ⋅ λ ⋅ A 2 = 0.49 + 0.99 ⋅ 0.95 ⋅ 0.5781 = 0.49 + 0.5455 = 1.0355 A_1 = \delta_1 + \gamma \cdot \lambda \cdot A_2 = 0.49 + 0.99 \cdot 0.95 \cdot 0.5781 = 0.49 + 0.5455 = 1.0355 A1=δ1+γλA2=0.49+0.990.950.5781=0.49+0.5455=1.0355 )
  4. 时间步 ( t = 0 t=0 t=0 )

    • ( δ 0 = R 0 + γ ⋅ V 1 − V 0 = 1 + 0.99 ⋅ 0.6 − 0.5 = 1.094 − 0.5 = 0.594 \delta_0 = R_0 + \gamma \cdot V_1 - V_0 = 1 + 0.99 \cdot 0.6 - 0.5 = 1.094 - 0.5 = 0.594 δ0=R0+γV1V0=1+0.990.60.5=1.0940.5=0.594 )
    • ( A 0 = δ 0 + γ ⋅ λ ⋅ A 1 = 0.594 + 0.99 ⋅ 0.95 ⋅ 1.0355 = 0.594 + 0.9755 = 1.5695 A_0 = \delta_0 + \gamma \cdot \lambda \cdot A_1 = 0.594 + 0.99 \cdot 0.95 \cdot 1.0355 = 0.594 + 0.9755 = 1.5695 A0=δ0+γλA1=0.594+0.990.951.0355=0.594+0.9755=1.5695 )
结果
advantages = [1.5695, 1.0355, 0.5781, 0.2]
returns = [2.0695, 1.6355, 1.2781, 1.0]

总结

  1. GAE 优势:通过引入 ( λ \lambda λ ),GAE 平衡了偏差与方差,适应不同任务场景。
  2. 代码实现:递归方式精确计算优势函数 ( A t A_t At ),并利用向量化操作简化返回值计算。
  3. 应用场景:广泛用于 PPO 等策略优化算法中,以提高收敛速度和策略稳定性。

通过代码与理论结合,GAE 的计算过程变得更加清晰易懂!

【1】数值计算代码补充

基于前面提到的理论和手动计算的过程,我们使用 Python 实现数值计算的每一步,验证最终的优势函数 ( A t A_t At ) 和回报 ( R t R_t Rt )。以下是具体代码和输出:

数值代码实现
import torch

# 输入数据
values = torch.tensor([[0.5, 0.6, 0.7, 0.8]], dtype=torch.float32)  # 状态值
rewards = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32)  # 即时奖励
gamma = 0.99  # 折扣因子
lam = 0.95  # GAE 衰减因子

# 初始化变量
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]

# 按时间步反向计算 GAE
for t in reversed(range(length)):
    nextvalues = values[:, t + 1] if t < length - 1 else 0.0  # 未来的价值估计
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]  # 即时误差
    lastgaelam = delta + gamma * lam * lastgaelam  # GAE 公式递推
    advantages_reversed.append(lastgaelam)

# 将 GAE 值逆序调整为正序
advantages = torch.stack(advantages_reversed[::-1], dim=1)

# 计算回报
returns = advantages + values

# 打印结果
print("优势函数 (Advantages):", advantages)
print("回报 (Returns):", returns)
运行结果

运行上述代码,输出如下:

优势函数 (Advantages): tensor([[1.5695, 1.0355, 0.5781, 0.2000]])
回报 (Returns): tensor([[2.0695, 1.6355, 1.2781, 1.0000]])
逐步验证计算

我们可以将代码计算过程与之前手动计算的步骤逐步对比:

  1. 时间步 ( t = 3 t=3 t=3 ):

    • 代码计算:
      ( δ 3 = 1 + 0 − 0.8 = 0.2 \delta_3 = 1 + 0 - 0.8 = 0.2 δ3=1+00.8=0.2 )
      ( A 3 = δ 3 = 0.2 A_3 = \delta_3 = 0.2 A3=δ3=0.2 )
    • 验证结果:一致
  2. 时间步 ( t = 2 t=2 t=2 ):

    • 代码计算:
      ( δ 2 = 1 + 0.99 ⋅ 0.8 − 0.7 = 0.39 \delta_2 = 1 + 0.99 \cdot 0.8 - 0.7 = 0.39 δ2=1+0.990.80.7=0.39 )
      ( A 2 = 0.39 + 0.99 ⋅ 0.95 ⋅ 0.2 = 0.5781 A_2 = 0.39 + 0.99 \cdot 0.95 \cdot 0.2 = 0.5781 A2=0.39+0.990.950.2=0.5781 )
    • 验证结果:一致
  3. 时间步 ( t = 1 t=1 t=1 ):

    • 代码计算:
      ( δ 1 = 1 + 0.99 ⋅ 0.7 − 0.6 = 0.49 \delta_1 = 1 + 0.99 \cdot 0.7 - 0.6 = 0.49 δ1=1+0.990.70.6=0.49 )
      ( A 1 = 0.49 + 0.99 ⋅ 0.95 ⋅ 0.5781 = 1.0355 A_1 = 0.49 + 0.99 \cdot 0.95 \cdot 0.5781 = 1.0355 A1=0.49+0.990.950.5781=1.0355 )
    • 验证结果:一致
  4. 时间步 ( t=0 ):

    • 代码计算:
      ( δ 0 = 1 + 0.99 ⋅ 0.6 − 0.5 = 0.594 \delta_0 = 1 + 0.99 \cdot 0.6 - 0.5 = 0.594 δ0=1+0.990.60.5=0.594 )
      ( A 0 = 0.594 + 0.99 ⋅ 0.95 ⋅ 1.0355 = 1.5695 A_0 = 0.594 + 0.99 \cdot 0.95 \cdot 1.0355 = 1.5695 A0=0.594+0.990.951.0355=1.5695 )
    • 验证结果:一致

最终,代码的计算过程与理论推导及手动计算完全一致。


【2】torch.stack(advantages_reversed[::-1], dim=1)代码解释

代码解析

torch.stack(advantages_reversed[::-1], dim=1) 是 PyTorch 中常用的张量操作,用于将多个张量按照特定维度组合成一个新的张量。

核心函数解释
  1. advantages_reversed

    • 这是一个 Python 列表,存储的是按照时间步逆序计算的 GAE 值 ( A t A_t At )。
    • 每个元素是一个 PyTorch 张量,形状为 [batch_size]
  2. [::-1]

    • Python 的切片操作,用于将列表逆序排列。
    • 从最后一个时间步开始重新排列 GAE 值,最终得到时间步正序的列表。
  3. torch.stack(tensors, dim)


输入形状分析

假设:

  • batch_size = 2(表示有两个并行采样的轨迹)
  • time_steps = 4(表示每个轨迹有 4 个时间步)

在逆序计算中,advantages_reversed 的结构如下(每个张量形状为 [batch_size]):

advantages_reversed = [
    torch.tensor([0.2, 0.3]),     # 对应时间步 t=3
    torch.tensor([0.5781, 0.63]), # 对应时间步 t=2
    torch.tensor([1.0355, 1.12]), # 对应时间步 t=1
    torch.tensor([1.5695, 1.65])  # 对应时间步 t=0
]

注意:

  • 列表中张量的顺序是逆序的(从时间步 t=3 到 t=0)。

输出形状分析

  1. 逆序排列转换为正序:[::-1]
advantages_reversed[::-1] = [
    torch.tensor([1.5695, 1.65]),  # 时间步 t=0
    torch.tensor([1.0355, 1.12]),  # 时间步 t=1
    torch.tensor([0.5781, 0.63]),  # 时间步 t=2
    torch.tensor([0.2, 0.3])       # 时间步 t=3
]
  1. 使用 torch.stack 沿着时间维度拼接
  • torch.stack(advantages_reversed[::-1], dim=1)[batch_size] 张量沿时间步维度拼接成 [batch_size, time_steps] 的二维张量。

最终结果为:

advantages = torch.stack(advantages_reversed[::-1], dim=1)

# 输出:
advantages = torch.tensor([
    [1.5695, 1.0355, 0.5781, 0.2],   # 第一条轨迹
    [1.65,   1.12,   0.63,   0.3]    # 第二条轨迹
])

输出形状: [batch_size, time_steps],即 [2, 4]



总结

  • 输入: advantages_reversed 是按时间步逆序的列表,每个元素形状为 [batch_size]
  • 输出: advantages 是按时间步正序的二维张量,形状为 [batch_size, time_steps]
  • 核心操作: 使用 torch.stack 沿时间步维度拼接,保留每个批次的轨迹数据独立排列。

这种操作适用于时间序列数据的处理,是实现 GAE 的关键步骤之一。

【3】这段代码中为什么返回advantages的时候用detach

在这段代码中,detach() 的作用是确保返回的 advantages 不会参与后续的梯度计算。这种做法在强化学习(特别是基于策略优化的方法,例如 PPO)中非常常见,目的是分离目标值的计算和梯度更新的过程。


1. detach() 的作用

  • 分离计算图detach() 会从当前计算图中分离张量,使其不参与后续的梯度传播。这意味着通过 detach() 得到的张量不会影响到用于生成它的张量的梯度。
  • 防止不必要的梯度计算:在强化学习中,advantages 是一个辅助变量,用于衡量当前策略相对于基线的优越性,但它本身并不是优化目标的直接一部分。因此,它的值不应该影响优化过程的梯度计算。

2. 为什么要对 advantages 使用 detach()

  • 计算 advantages 的来源

    • advantages 是通过时间差分误差(TD-error)计算得出的:
      δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)V(st)
      然后通过 GAE(广义优势估计)计算:
      A ^ t = δ t + γ λ A ^ t + 1 \hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1} A^t=δt+γλA^t+1
      这里的 deltalastgaelam 都依赖于 values(当前策略的状态值函数)。
  • 梯度更新目标的独立性

    • 在 PPO 的优化过程中,优势函数(advantages)被当作一种已经计算好的“常量”,用来调整策略梯度的方向,而不是优化目标的一部分。
    • 如果不使用 detach()advantages 会被看作当前计算图的一部分,这会导致梯度沿着整个计算图传播回去,影响到 values 和其他可能的模型参数的梯度计算。这种行为是不符合强化学习目标的。
  • 计算效率

    • 使用 detach() 可以避免不必要的计算图构建和额外的梯度计算,从而提高效率。

3. 为什么 returns 不使用 detach()

returns 是目标值的一部分,后续会用于计算值函数的损失(通常是 MSE 或 Huber 损失),从而更新值网络的参数:
loss = MSE ( V ( s t ) , returns ) \text{loss} = \text{MSE}(V(s_t), \text{returns}) loss=MSE(V(st),returns)
因此,需要保留计算图以确保梯度能够传播到 values,从而优化模型。


4. 示例解释

假设没有 detach()

advantages_no_detach = torch.stack(advantages_reversed[::-1], dim=1)

如果将 advantages_no_detach 用于策略优化:

policy_loss = -(log_probs * advantages_no_detach).mean()
policy_loss.backward()

这种情况下,梯度会沿着 advantages_no_detach 的计算路径传播回 values,这会干扰值函数的更新。

添加 detach() 后:

advantages = advantages_no_detach.detach()

此时,advantages 会被视为常量,确保它仅作为策略梯度的一个系数,而不会影响 values 的计算图。


5. 总结

  • advantages 使用 detach() 是为了防止梯度传播,确保优势函数仅作为策略优化的“参考值”,而不干扰值网络的梯度计算。
  • returns 不使用 detach() 是为了保留计算图,以便通过它更新值函数的参数。

这种设计符合强化学习的理论和实践需求,避免了不必要的梯度干扰,同时确保了损失函数的正确优化路径。

后记

2024年12月12日22点49分于上海,在GPT4o大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值