DeepSpeed-Chat中PPO算法中采用GAE方法,这里学习记录一下。
def get_advantages_and_returns(self, values, rewards, start):
# Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
for t in reversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values[:, start:]
return advantages.detach(), returns
下面我们将详细分析代码 get_advantages_and_returns
的实现,并结合理论解释和实际数值例子来讲解 Generalized Advantage Estimation (GAE) 的应用和计算过程。最终整理成一篇博客形式。
深入理解 Generalized Advantage Estimation (GAE) 及其代码实现
背景理论
在强化学习中,策略梯度方法通常需要估计每个动作的优势函数 ( A t A_t At )。优势函数衡量了一个动作比平均水平好的程度。公式如下:
A t = R t + γ V t + 1 − V t A_t = R_t + \gamma V_{t+1} - V_t At=Rt+γVt+1−Vt
其中:
- ( R t R_t Rt ):即时奖励
- ( V t V_t Vt ):当前状态 ( S t S_t St ) 的价值估计
- ( γ \gamma γ ):折扣因子,用于控制未来奖励的重要性
GAE 的核心思想是将当前时间步 ( t t t ) 的优势函数 ( A t A_t At ) 拆分成即时误差 ( δ t \delta_t δt ) 和未来累计误差的加权和。通过引入参数 ( λ \lambda λ ),GAE 可以控制优势估计对长期误差的关注程度。其递推公式为:
δ t = R t + γ V t + 1 − V t \delta_t = R_t + \gamma V_{t+1} - V_t δt=Rt+γVt+1−Vt
A t = δ t + γ λ A t + 1 A_t = \delta_t + \gamma \lambda A_{t+1} At=δt+γλAt+1
参数 ( λ \lambda λ ) 控制了对未来优势的衰减。当 ( λ = 0 \lambda = 0 λ=0 ) 时,只考虑当前的即时误差;当 ( λ = 1 \lambda = 1 λ=1 ) 时,将所有未来的误差都考虑在内。
GAE 的优点是通过引入 ( λ \lambda λ ) 来平衡偏差与方差,从而提高策略梯度方法的稳定性和效率。
代码解析
代码部分
def get_advantages_and_returns(self, values, rewards, start):
# Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
for t in reversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values[:, start:]
return advantages.detach(), returns
核心逻辑分解
-
输入参数:
values
:每个时间步的状态价值估计 ( V t V_t Vt )(形状为 [batch, time])。rewards
:每个时间步的即时奖励 ( R t R_t Rt )。start
:从哪个时间步开始计算 GAE。
-
初始化:
lastgaelam
:存储当前时间步的 GAE 值。advantages_reversed
:存储逆序的 GAE 结果。
-
循环计算:
- 从最后一个时间步向前递推:
- 计算当前时间步的即时误差 (
δ
t
\delta_t
δt ):
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
- 根据 GAE 公式递推更新 (
A
t
A_t
At ):
lastgaelam = delta + self.gamma * self.lam * lastgaelam
- 将 ( A t A_t At ) 存储到逆序列表中。
- 计算当前时间步的即时误差 (
δ
t
\delta_t
δt ):
- 从最后一个时间步向前递推:
-
优势与回报计算:
- 将逆序的 GAE 转为正序:
advantages = torch.stack(advantages_reversed[::-1], dim=1)
- 计算回报 (
R
t
R_t
Rt ):
returns = advantages + values[:, start:]
- 将逆序的 GAE 转为正序:
-
输出结果:
- 返回 GAE 的优势函数 ( A t A_t At ) 和回报 ( R t R_t Rt )。
实际数值例子
假设我们有以下输入数据:
values = torch.tensor([[0.5, 0.6, 0.7, 0.8]], dtype=torch.float32) # 状态值
rewards = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32) # 即时奖励
gamma = 0.99
lam = 0.95
计算过程
-
时间步 ( t = 3 t=3 t=3 ):
- ( δ 3 = R 3 + γ ⋅ 0 − V 3 = 1 + 0 − 0.8 = 0.2 \delta_3 = R_3 + \gamma \cdot 0 - V_3 = 1 + 0 - 0.8 = 0.2 δ3=R3+γ⋅0−V3=1+0−0.8=0.2 )
- ( A 3 = δ 3 = 0.2 A_3 = \delta_3 = 0.2 A3=δ3=0.2 )
-
时间步 ( t = 2 t=2 t=2 ):
- ( δ 2 = R 2 + γ ⋅ V 3 − V 2 = 1 + 0.99 ⋅ 0.8 − 0.7 = 1.09 − 0.7 = 0.39 \delta_2 = R_2 + \gamma \cdot V_3 - V_2 = 1 + 0.99 \cdot 0.8 - 0.7 = 1.09 - 0.7 = 0.39 δ2=R2+γ⋅V3−V2=1+0.99⋅0.8−0.7=1.09−0.7=0.39 )
- ( A 2 = δ 2 + γ ⋅ λ ⋅ A 3 = 0.39 + 0.99 ⋅ 0.95 ⋅ 0.2 = 0.39 + 0.1881 = 0.5781 A_2 = \delta_2 + \gamma \cdot \lambda \cdot A_3 = 0.39 + 0.99 \cdot 0.95 \cdot 0.2 = 0.39 + 0.1881 = 0.5781 A2=δ2+γ⋅λ⋅A3=0.39+0.99⋅0.95⋅0.2=0.39+0.1881=0.5781 )
-
时间步 ( t = 1 t=1 t=1 ):
- ( δ 1 = R 1 + γ ⋅ V 2 − V 1 = 1 + 0.99 ⋅ 0.7 − 0.6 = 1.09 − 0.6 = 0.49 \delta_1 = R_1 + \gamma \cdot V_2 - V_1 = 1 + 0.99 \cdot 0.7 - 0.6 = 1.09 - 0.6 = 0.49 δ1=R1+γ⋅V2−V1=1+0.99⋅0.7−0.6=1.09−0.6=0.49 )
- ( A 1 = δ 1 + γ ⋅ λ ⋅ A 2 = 0.49 + 0.99 ⋅ 0.95 ⋅ 0.5781 = 0.49 + 0.5455 = 1.0355 A_1 = \delta_1 + \gamma \cdot \lambda \cdot A_2 = 0.49 + 0.99 \cdot 0.95 \cdot 0.5781 = 0.49 + 0.5455 = 1.0355 A1=δ1+γ⋅λ⋅A2=0.49+0.99⋅0.95⋅0.5781=0.49+0.5455=1.0355 )
-
时间步 ( t = 0 t=0 t=0 ):
- ( δ 0 = R 0 + γ ⋅ V 1 − V 0 = 1 + 0.99 ⋅ 0.6 − 0.5 = 1.094 − 0.5 = 0.594 \delta_0 = R_0 + \gamma \cdot V_1 - V_0 = 1 + 0.99 \cdot 0.6 - 0.5 = 1.094 - 0.5 = 0.594 δ0=R0+γ⋅V1−V0=1+0.99⋅0.6−0.5=1.094−0.5=0.594 )
- ( A 0 = δ 0 + γ ⋅ λ ⋅ A 1 = 0.594 + 0.99 ⋅ 0.95 ⋅ 1.0355 = 0.594 + 0.9755 = 1.5695 A_0 = \delta_0 + \gamma \cdot \lambda \cdot A_1 = 0.594 + 0.99 \cdot 0.95 \cdot 1.0355 = 0.594 + 0.9755 = 1.5695 A0=δ0+γ⋅λ⋅A1=0.594+0.99⋅0.95⋅1.0355=0.594+0.9755=1.5695 )
结果
advantages = [1.5695, 1.0355, 0.5781, 0.2]
returns = [2.0695, 1.6355, 1.2781, 1.0]
总结
- GAE 优势:通过引入 ( λ \lambda λ ),GAE 平衡了偏差与方差,适应不同任务场景。
- 代码实现:递归方式精确计算优势函数 ( A t A_t At ),并利用向量化操作简化返回值计算。
- 应用场景:广泛用于 PPO 等策略优化算法中,以提高收敛速度和策略稳定性。
通过代码与理论结合,GAE 的计算过程变得更加清晰易懂!
【1】数值计算代码补充
基于前面提到的理论和手动计算的过程,我们使用 Python 实现数值计算的每一步,验证最终的优势函数 ( A t A_t At ) 和回报 ( R t R_t Rt )。以下是具体代码和输出:
数值代码实现
import torch
# 输入数据
values = torch.tensor([[0.5, 0.6, 0.7, 0.8]], dtype=torch.float32) # 状态值
rewards = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32) # 即时奖励
gamma = 0.99 # 折扣因子
lam = 0.95 # GAE 衰减因子
# 初始化变量
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
# 按时间步反向计算 GAE
for t in reversed(range(length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0 # 未来的价值估计
delta = rewards[:, t] + gamma * nextvalues - values[:, t] # 即时误差
lastgaelam = delta + gamma * lam * lastgaelam # GAE 公式递推
advantages_reversed.append(lastgaelam)
# 将 GAE 值逆序调整为正序
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# 计算回报
returns = advantages + values
# 打印结果
print("优势函数 (Advantages):", advantages)
print("回报 (Returns):", returns)
运行结果
运行上述代码,输出如下:
优势函数 (Advantages): tensor([[1.5695, 1.0355, 0.5781, 0.2000]])
回报 (Returns): tensor([[2.0695, 1.6355, 1.2781, 1.0000]])
逐步验证计算
我们可以将代码计算过程与之前手动计算的步骤逐步对比:
-
时间步 ( t = 3 t=3 t=3 ):
- 代码计算:
( δ 3 = 1 + 0 − 0.8 = 0.2 \delta_3 = 1 + 0 - 0.8 = 0.2 δ3=1+0−0.8=0.2 )
( A 3 = δ 3 = 0.2 A_3 = \delta_3 = 0.2 A3=δ3=0.2 ) - 验证结果:一致
- 代码计算:
-
时间步 ( t = 2 t=2 t=2 ):
- 代码计算:
( δ 2 = 1 + 0.99 ⋅ 0.8 − 0.7 = 0.39 \delta_2 = 1 + 0.99 \cdot 0.8 - 0.7 = 0.39 δ2=1+0.99⋅0.8−0.7=0.39 )
( A 2 = 0.39 + 0.99 ⋅ 0.95 ⋅ 0.2 = 0.5781 A_2 = 0.39 + 0.99 \cdot 0.95 \cdot 0.2 = 0.5781 A2=0.39+0.99⋅0.95⋅0.2=0.5781 ) - 验证结果:一致
- 代码计算:
-
时间步 ( t = 1 t=1 t=1 ):
- 代码计算:
( δ 1 = 1 + 0.99 ⋅ 0.7 − 0.6 = 0.49 \delta_1 = 1 + 0.99 \cdot 0.7 - 0.6 = 0.49 δ1=1+0.99⋅0.7−0.6=0.49 )
( A 1 = 0.49 + 0.99 ⋅ 0.95 ⋅ 0.5781 = 1.0355 A_1 = 0.49 + 0.99 \cdot 0.95 \cdot 0.5781 = 1.0355 A1=0.49+0.99⋅0.95⋅0.5781=1.0355 ) - 验证结果:一致
- 代码计算:
-
时间步 ( t=0 ):
- 代码计算:
( δ 0 = 1 + 0.99 ⋅ 0.6 − 0.5 = 0.594 \delta_0 = 1 + 0.99 \cdot 0.6 - 0.5 = 0.594 δ0=1+0.99⋅0.6−0.5=0.594 )
( A 0 = 0.594 + 0.99 ⋅ 0.95 ⋅ 1.0355 = 1.5695 A_0 = 0.594 + 0.99 \cdot 0.95 \cdot 1.0355 = 1.5695 A0=0.594+0.99⋅0.95⋅1.0355=1.5695 ) - 验证结果:一致
- 代码计算:
最终,代码的计算过程与理论推导及手动计算完全一致。
【2】torch.stack(advantages_reversed[::-1], dim=1)代码解释
代码解析
torch.stack(advantages_reversed[::-1], dim=1)
是 PyTorch 中常用的张量操作,用于将多个张量按照特定维度组合成一个新的张量。
核心函数解释
-
advantages_reversed
- 这是一个 Python 列表,存储的是按照时间步逆序计算的 GAE 值 ( A t A_t At )。
- 每个元素是一个 PyTorch 张量,形状为
[batch_size]
。
-
[::-1]
- Python 的切片操作,用于将列表逆序排列。
- 从最后一个时间步开始重新排列 GAE 值,最终得到时间步正序的列表。
-
torch.stack(tensors, dim)
- 将一个张量列表沿指定维度拼接成一个新张量。 具体可以参考笔者的另一篇博客:深入理解 PyTorch 中的torch.stack函数:中英双语
dim=1
表示沿着 第 1 维(时间步) 拼接,每一列代表一个时间步的优势值。
输入形状分析
假设:
batch_size = 2
(表示有两个并行采样的轨迹)time_steps = 4
(表示每个轨迹有 4 个时间步)
在逆序计算中,advantages_reversed
的结构如下(每个张量形状为 [batch_size]
):
advantages_reversed = [
torch.tensor([0.2, 0.3]), # 对应时间步 t=3
torch.tensor([0.5781, 0.63]), # 对应时间步 t=2
torch.tensor([1.0355, 1.12]), # 对应时间步 t=1
torch.tensor([1.5695, 1.65]) # 对应时间步 t=0
]
注意:
- 列表中张量的顺序是逆序的(从时间步 t=3 到 t=0)。
输出形状分析
- 逆序排列转换为正序:
[::-1]
advantages_reversed[::-1] = [
torch.tensor([1.5695, 1.65]), # 时间步 t=0
torch.tensor([1.0355, 1.12]), # 时间步 t=1
torch.tensor([0.5781, 0.63]), # 时间步 t=2
torch.tensor([0.2, 0.3]) # 时间步 t=3
]
- 使用
torch.stack
沿着时间维度拼接
torch.stack(advantages_reversed[::-1], dim=1)
将[batch_size]
张量沿时间步维度拼接成[batch_size, time_steps]
的二维张量。
最终结果为:
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# 输出:
advantages = torch.tensor([
[1.5695, 1.0355, 0.5781, 0.2], # 第一条轨迹
[1.65, 1.12, 0.63, 0.3] # 第二条轨迹
])
输出形状: [batch_size, time_steps]
,即 [2, 4]
。
总结
- 输入:
advantages_reversed
是按时间步逆序的列表,每个元素形状为[batch_size]
。 - 输出:
advantages
是按时间步正序的二维张量,形状为[batch_size, time_steps]
。 - 核心操作: 使用
torch.stack
沿时间步维度拼接,保留每个批次的轨迹数据独立排列。
这种操作适用于时间序列数据的处理,是实现 GAE 的关键步骤之一。
【3】这段代码中为什么返回advantages的时候用detach
在这段代码中,detach()
的作用是确保返回的 advantages
不会参与后续的梯度计算。这种做法在强化学习(特别是基于策略优化的方法,例如 PPO)中非常常见,目的是分离目标值的计算和梯度更新的过程。
1. detach()
的作用
- 分离计算图:
detach()
会从当前计算图中分离张量,使其不参与后续的梯度传播。这意味着通过detach()
得到的张量不会影响到用于生成它的张量的梯度。 - 防止不必要的梯度计算:在强化学习中,
advantages
是一个辅助变量,用于衡量当前策略相对于基线的优越性,但它本身并不是优化目标的直接一部分。因此,它的值不应该影响优化过程的梯度计算。
2. 为什么要对 advantages
使用 detach()
-
计算
advantages
的来源:advantages
是通过时间差分误差(TD-error)计算得出的:
δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)−V(st)
然后通过 GAE(广义优势估计)计算:
A ^ t = δ t + γ λ A ^ t + 1 \hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1} A^t=δt+γλA^t+1
这里的delta
和lastgaelam
都依赖于values
(当前策略的状态值函数)。
-
梯度更新目标的独立性:
- 在 PPO 的优化过程中,优势函数(
advantages
)被当作一种已经计算好的“常量”,用来调整策略梯度的方向,而不是优化目标的一部分。 - 如果不使用
detach()
,advantages
会被看作当前计算图的一部分,这会导致梯度沿着整个计算图传播回去,影响到values
和其他可能的模型参数的梯度计算。这种行为是不符合强化学习目标的。
- 在 PPO 的优化过程中,优势函数(
-
计算效率:
- 使用
detach()
可以避免不必要的计算图构建和额外的梯度计算,从而提高效率。
- 使用
3. 为什么 returns
不使用 detach()
returns
是目标值的一部分,后续会用于计算值函数的损失(通常是 MSE 或 Huber 损失),从而更新值网络的参数:
loss
=
MSE
(
V
(
s
t
)
,
returns
)
\text{loss} = \text{MSE}(V(s_t), \text{returns})
loss=MSE(V(st),returns)
因此,需要保留计算图以确保梯度能够传播到 values
,从而优化模型。
4. 示例解释
假设没有 detach()
:
advantages_no_detach = torch.stack(advantages_reversed[::-1], dim=1)
如果将 advantages_no_detach
用于策略优化:
policy_loss = -(log_probs * advantages_no_detach).mean()
policy_loss.backward()
这种情况下,梯度会沿着 advantages_no_detach
的计算路径传播回 values
,这会干扰值函数的更新。
添加 detach()
后:
advantages = advantages_no_detach.detach()
此时,advantages
会被视为常量,确保它仅作为策略梯度的一个系数,而不会影响 values
的计算图。
5. 总结
- 对
advantages
使用detach()
是为了防止梯度传播,确保优势函数仅作为策略优化的“参考值”,而不干扰值网络的梯度计算。 - 对
returns
不使用detach()
是为了保留计算图,以便通过它更新值函数的参数。
这种设计符合强化学习的理论和实践需求,避免了不必要的梯度干扰,同时确保了损失函数的正确优化路径。
后记
2024年12月12日22点49分于上海,在GPT4o大模型辅助下完成。