A Deep Dive into the PPO Algorithm, with the Stable-Baselines3 Source Code
1. Introduction to PPO
PPO (Proximal Policy Optimization) is a widely used reinforcement learning algorithm proposed by OpenAI in 2017. It builds on policy gradient methods and introduces a clipping mechanism that limits how far the policy can move in a single update, which improves training stability and sample efficiency.
PPO has two main variants:
- PPO-Clip: clips the probability ratio so that policy updates stay within a bounded range, preventing overly large policy changes.
- PPO-Penalty: adds a KL divergence (Kullback-Leibler divergence) term to the objective as a regularizer.
Stable-Baselines3 implements the PPO-Clip variant, and the rest of this post walks through the details of that implementation alongside the source code.
For an introduction to the theory behind PPO, see the companion post 强化学习之PPO算法 (an overview of the PPO algorithm).
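As a quick orientation before diving into the internals, here is a minimal usage sketch. It assumes a recent Stable-Baselines3 release with Gymnasium support, and "CartPole-v1" is just an example environment:

```python
import gymnasium as gym
from stable_baselines3 import PPO

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)  # the PPO-Clip variant discussed below
model.learn(total_timesteps=50_000)       # alternates collect_rollouts() and train()

obs, _ = env.reset()
action, _state = model.predict(obs, deterministic=True)
```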
2. The PPO Implementation in Stable-Baselines3
2.1 Training Flow Overview
PPO inherits from `OnPolicyAlgorithm`, and its training procedure consists of the following core steps:
- Rollout collection
- Advantage estimation (GAE-Lambda)
- Policy network update
- Value network update
- Optimization of the clipped objective
- KL divergence monitoring to decide whether to stop early
The training logic is implemented mainly in `PPO.train()`, which the rest of this section walks through step by step.
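To make the loop concrete, the small calculation below shows how the rollout and update sizes relate to each other. The numeric values are SB3's documented PPO defaults except for the number of environments, which is just an example; the variable names are only for illustration:

```python
n_envs = 4        # number of parallel environments in the VecEnv (example value)
n_steps = 2048    # rollout length per environment (SB3 default)
batch_size = 64   # minibatch size used inside train() (SB3 default)
n_epochs = 10     # passes over the rollout buffer per update (SB3 default)

rollout_size = n_envs * n_steps                     # 8192 transitions collected per update
minibatches_per_epoch = rollout_size // batch_size  # 128 minibatches per epoch
gradient_steps = minibatches_per_epoch * n_epochs   # 1280 optimizer steps per rollout
print(rollout_size, minibatches_per_epoch, gradient_steps)
```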
2.2 Code Walkthrough: The PPO Training Loop
(1) Initialize training parameters
```python
self.policy.set_training_mode(True)
self._update_learning_rate(self.policy.optimizer)
clip_range = self.clip_range(self._current_progress_remaining)
```
- `self.policy.set_training_mode(True)`: puts the model into training mode, which affects layers such as Batch Normalization and Dropout.
- `self._update_learning_rate()`: adjusts the learning rate according to the current training progress.
- `clip_range`: the clipping range computed from the current progress; it controls how far the policy may move in this update (see the schedule sketch below).
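Both `learning_rate` and `clip_range` accept either a constant or a schedule, i.e. a callable of `progress_remaining`, which SB3 anneals from 1.0 down to 0.0 over the course of training. A small sketch of how such schedules can be passed in:

```python
from stable_baselines3 import PPO

def linear_clip_schedule(progress_remaining: float) -> float:
    # 0.2 at the start of training, annealed linearly towards 0.0 at the end
    return 0.2 * progress_remaining

model = PPO(
    "MlpPolicy", "CartPole-v1",
    learning_rate=lambda progress_remaining: 3e-4 * progress_remaining,
    clip_range=linear_clip_schedule,
)
```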
(2) Iterate over minibatches and compute the losses
```python
for rollout_data in self.rollout_buffer.get(self.batch_size):
    actions = rollout_data.actions
    if isinstance(self.action_space, spaces.Discrete):
        actions = rollout_data.actions.long().flatten()
```
- `rollout_data` holds the samples collected in the current rollout, including the observations, actions, and advantages, among other fields.
- For discrete action spaces, the actions are cast to `long` and flattened.
(3) Compute the policy loss
```python
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
```
- `self.policy.evaluate_actions()`: evaluates the stored observations and actions under the current policy, returning the state values (`values`), the action log-probabilities (`log_prob`), and the entropy.
- `ratio = th.exp(log_prob - rollout_data.old_log_prob)`: the probability ratio between the new and old policy, which measures how much the policy has changed for these actions (see the sketch below).
- How `log_prob` is computed for different action spaces is covered in 强化学习之离散动作采样 vs 连续动作采样 (discrete vs. continuous action sampling).
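The sketch below illustrates what the ratio measures, using `torch.distributions` directly rather than SB3's internal distribution classes; the logits and the action are made-up values:

```python
import torch as th
from torch.distributions import Categorical

logits_old = th.tensor([[2.0, 0.5, 0.1]])  # policy that collected the data
logits_new = th.tensor([[1.5, 1.0, 0.1]])  # current policy being optimized
action = th.tensor([0])

old_log_prob = Categorical(logits=logits_old).log_prob(action)
log_prob = Categorical(logits=logits_new).log_prob(action)

ratio = th.exp(log_prob - old_log_prob)  # > 1: action became more likely; < 1: less likely
print(ratio)
```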
PPO uses the clipped surrogate objective from the original PPO paper:
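$$
L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\!\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
$$

where $\hat{A}_t$ is the estimated advantage and $\epsilon$ is `clip_range`. The code minimizes the negative of this objective as the policy loss: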
```python
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
```
- `policy_loss_1`: the standard (unclipped) policy gradient surrogate.
- `policy_loss_2`: the clipped surrogate, which keeps the update from drifting too far from the old policy.
- `policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()`: taking the elementwise minimum (and negating it for gradient descent) gives a pessimistic bound on the objective, which keeps policy updates stable; a short numeric check follows.
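A tiny numeric check (with made-up advantage and ratio values) shows the effect of the clipping: once the ratio leaves the [1 - clip_range, 1 + clip_range] band in the direction favoured by the advantage, the gradient through that sample drops to zero, so the policy stops being pushed further in that direction:

```python
import torch as th

clip_range = 0.2
advantages = th.tensor([1.0, 1.0])
ratio = th.tensor([1.5, 0.9], requires_grad=True)  # one sample outside, one inside the band

policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
policy_loss.backward()
print(ratio.grad)  # tensor([0.0000, -0.5000]): the clipped sample contributes no gradient
```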
(4) Compute the value function loss
```python
value_loss = F.mse_loss(rollout_data.returns, values_pred)
```
- The value network is trained with a mean squared error (MSE) loss between the predictions `values_pred` and the targets `rollout_data.returns` (the optional clipped-value variant is sketched below).
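SB3 also supports an optional refinement that is visible in the appendix: when `clip_range_vf` is set, the new value prediction may move at most `clip_range_vf` away from the value recorded at rollout time before the MSE is taken. A toy illustration with made-up numbers:

```python
import torch as th
import torch.nn.functional as F

clip_range_vf = 0.2
old_values = th.tensor([1.00, 1.00])  # value predictions stored at rollout time
values = th.tensor([1.50, 1.05])      # new predictions: one far from, one close to the old value
returns = th.tensor([1.60, 1.10])     # regression targets

values_pred = old_values + th.clamp(values - old_values, -clip_range_vf, clip_range_vf)
value_loss = F.mse_loss(returns, values_pred)  # the first prediction is capped at 1.2
print(values_pred, value_loss)
```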
(5) Compute the entropy loss to encourage exploration
```python
entropy_loss = -th.mean(entropy)
```
- The entropy loss is the negative mean entropy of the action distribution; adding it to the total loss rewards more stochastic policies and therefore encourages exploration (see the toy comparison below).
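Minimizing `-th.mean(entropy)` is the same as maximizing the entropy of the action distribution. The toy comparison below, using `torch.distributions` directly with made-up probabilities, shows that a near-deterministic policy has much lower entropy than a uniform one, so the entropy term rewards keeping some randomness:

```python
import torch as th
from torch.distributions import Categorical

peaked = Categorical(probs=th.tensor([0.98, 0.01, 0.01]))
uniform = Categorical(probs=th.tensor([1 / 3, 1 / 3, 1 / 3]))
print(peaked.entropy(), uniform.entropy())  # roughly 0.11 vs 1.10 nats
```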
(6) Combine the losses and take a gradient step
```python
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
self.policy.optimizer.zero_grad()
loss.backward()
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
```
- Total loss = policy loss + entropy loss (weighted by `ent_coef`) + value loss (weighted by `vf_coef`); the constructor defaults for these coefficients are shown below.
- `zero_grad()`: clears the accumulated gradients.
- `backward()`: computes the gradients.
- `clip_grad_norm_()`: clips the gradient norm to prevent exploding gradients.
- `optimizer.step()`: updates the model parameters.
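The loss weights and the gradient-clipping threshold are constructor arguments of `PPO`; the values below are the library's documented defaults (note that `ent_coef` defaults to 0.0, so the entropy bonus is off unless you enable it explicitly):

```python
from stable_baselines3 import PPO

model = PPO(
    "MlpPolicy", "CartPole-v1",
    ent_coef=0.0,       # weight of the entropy bonus
    vf_coef=0.5,        # weight of the value-function loss
    max_grad_norm=0.5,  # threshold used by clip_grad_norm_
)
```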
(7) Monitor the KL divergence and decide whether to stop early
```python
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
    continue_training = False
    print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
```
- An approximate KL divergence between the old and new policy is computed for every minibatch; if it exceeds 1.5 × `target_kl`, the remaining epochs are skipped so the policy does not drift too far from the policy that collected the data (a numeric check of this estimator follows).
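The quantity `mean((exp(log_ratio) - 1) - log_ratio)` is the low-variance KL estimator discussed in the Schulman blog post linked in the source comments. A sanity check on toy categorical distributions (made-up probabilities) shows it matches the analytic KL between the old and new policy:

```python
import torch as th
from torch.distributions import Categorical, kl_divergence

old = Categorical(probs=th.tensor([0.5, 0.3, 0.2]))
new = Categorical(probs=th.tensor([0.4, 0.4, 0.2]))

actions = old.sample((100_000,))  # samples drawn from the old (data-collecting) policy
log_ratio = new.log_prob(actions) - old.log_prob(actions)
approx_kl = th.mean((th.exp(log_ratio) - 1) - log_ratio)
print(approx_kl, kl_divergence(old, new))  # both close to 0.025
```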
2.3 Code Walkthrough: Rollout Collection
PPO gathers its training data by interacting with the environment inside `OnPolicyAlgorithm.collect_rollouts()`.
```python
obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor)
new_obs, rewards, dones, infos = env.step(actions.cpu().numpy())
```
- `self.policy(obs_tensor)`: runs the policy to obtain the actions, value estimates, and log-probabilities.
- `env.step()`: executes the actions in the (vectorized) environment and returns the new observations, rewards, done flags, and info dicts.
The transitions are stored in the `rollout_buffer`; once the rollout is complete, the buffer computes the GAE advantages and the returns.
For background on advantages and returns, see 强化学习之Advantage优势函数 (the advantage function and returns).
```python
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
```
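For reference, here is a minimal single-environment sketch of GAE(lambda), simplified relative to SB3's `RolloutBuffer.compute_returns_and_advantage`, which additionally handles vectorized environments and episode-start bookkeeping:

```python
import numpy as np

def compute_gae(rewards, values, last_value, dones, gamma=0.99, gae_lambda=0.95):
    """Compute GAE advantages and lambda-returns for one trajectory (numpy arrays)."""
    advantages = np.zeros(len(rewards))
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        next_non_terminal = 1.0 - dones[t]  # no bootstrapping across episode ends
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values  # TD(lambda) targets used for the value loss
    return advantages, returns
```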
3. Summary
PPO in Stable-Baselines3 stabilizes policy updates with a clipped objective and uses KL divergence early stopping to avoid destructive updates that could make the policy collapse.
The core ideas of PPO:
- Bound the size of each policy change to improve training stability.
- Estimate advantages with GAE-Lambda to reduce variance and improve sample efficiency.
- Add an entropy bonus to encourage exploration and prevent premature convergence.
- Clip gradients to guard against exploding gradients and further stabilize training.
In practice, PPO handles large continuous and discrete action spaces and is widely used in robot control, game AI, financial trading, and similar settings.
4. Appendix
The main PPO-related code in stable-baselines3 (`PPO.train()`, `OnPolicyAlgorithm.collect_rollouts()`, and `OnPolicyAlgorithm.learn()`):
```python
# https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
class PPO(OnPolicyAlgorithm):
    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)


class OnPolicyAlgorithm(BaseAlgorithm):
    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.
        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions

            if isinstance(self.action_space, spaces.Box):
                if self.policy.squash_output:
                    # Unscale the actions to match env bounds
                    # if they were previously squashed (scaled in [-1, 1])
                    clipped_actions = self.policy.unscale_action(clipped_actions)
                else:
                    # Otherwise, clip the actions to avoid out of bound error
                    # as we are sampling from an unbounded Gaussian distribution
                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if not callback.on_step():
                return False

            self._update_info_buffer(infos, dones)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstrapping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(
                self._last_obs,  # type: ignore[arg-type]
                actions,
                rewards,
                self._last_episode_starts,  # type: ignore[arg-type]
                values,
                log_probs,
            )
            self._last_obs = new_obs  # type: ignore[assignment]
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.update_locals(locals())
        callback.on_rollout_end()

        return True

    def learn(
        self: SelfOnPolicyAlgorithm,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "OnPolicyAlgorithm",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfOnPolicyAlgorithm:
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            callback,
            reset_num_timesteps,
            tb_log_name,
            progress_bar,
        )

        callback.on_training_start(locals(), globals())

        assert self.env is not None

        while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

            if not continue_training:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                assert self.ep_info_buffer is not None
                self.dump_logs(iteration)

            self.train()

        callback.on_training_end()

        return self
```