Reinforcement Learning: Reading the PPO Source Code

A Deep Dive into the PPO Algorithm with the Stable-Baselines3 Source Code

1. Introduction to PPO

PPO (Proximal Policy Optimization) is a widely used reinforcement learning algorithm proposed by OpenAI in 2017. Building on policy gradient methods, it introduces a clipping technique that limits how much the policy can change in a single update, improving both training stability and data efficiency.

PPO comes in two main variants:

  • PPO-Clip: constrains each policy update by clipping the probability ratio, preventing excessively large policy changes.
  • PPO-Penalty: adds the KL divergence (Kullback-Leibler divergence) between the new and old policies to the objective as a regularization term.

Stable-Baselines3 implements the PPO-Clip variant, and we will walk through its implementation details alongside the code.
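For reference, the two objectives from the PPO paper can be written as follows, with probability ratio r_t(θ) and advantage estimate Â_t:

L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[ \min\left( r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t \right) \right]

L^{KLPEN}(\theta) = \hat{\mathbb{E}}_t\left[ r_t(\theta)\,\hat{A}_t - \beta\, \mathrm{KL}\big[ \pi_{\theta_{\mathrm{old}}}(\cdot \mid s_t)\,\|\,\pi_\theta(\cdot \mid s_t) \big] \right]

\text{where } r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}.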
For the theory behind the PPO algorithm, see the companion post on the PPO algorithm (强化学习之PPO算法).


2. PPO in Stable-Baselines3

2.1 Overview of the Training Loop

PPO inherits from OnPolicyAlgorithm; its training consists of the following core steps:

  1. Rollout collection
  2. Advantage estimation (GAE-Lambda)
  3. Policy network update
  4. Value function update
  5. Clipped objective optimization
  6. Approximate KL divergence check with early stopping

The training logic lives mainly in PPO.train(), which we walk through step by step below.
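Before reading the source, a minimal usage sketch helps put PPO.train() in context (the environment id and total_timesteps below are arbitrary choices for illustration):

from stable_baselines3 import PPO

# Create the model; PPO.learn() alternates collect_rollouts() and train() internally.
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000)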


2.2 Code Walkthrough: The PPO Training Step

(1) Initialize training parameters
self.policy.set_training_mode(True)
self._update_learning_rate(self.policy.optimizer)
clip_range = self.clip_range(self._current_progress_remaining)
  • self.policy.set_training_mode(True): switch the policy to training mode (this affects Batch Normalization and Dropout layers).
  • self._update_learning_rate(): adjust the optimizer's learning rate according to the current training progress.
  • clip_range: the current clipping range, which bounds how far the policy may move in this update.
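Both the learning rate and clip_range can be given as constants or as schedules. In SB3, a schedule is a callable of progress_remaining, which decreases from 1.0 at the start of training to 0.0 at the end. A minimal linear-schedule sketch, following the pattern in the SB3 documentation (the initial values are arbitrary):

from stable_baselines3 import PPO

def linear_schedule(initial_value: float):
    # Returns a function of progress_remaining, which goes from 1.0 down to 0.0.
    def schedule(progress_remaining: float) -> float:
        return progress_remaining * initial_value
    return schedule

model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    learning_rate=linear_schedule(3e-4),  # decays linearly to 0 over training
    clip_range=linear_schedule(0.2),      # clipping range shrinks as training progresses
)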
(2) Iterate over mini-batches and compute the losses
for rollout_data in self.rollout_buffer.get(self.batch_size):
    actions = rollout_data.actions
    if isinstance(self.action_space, spaces.Discrete):
        actions = rollout_data.actions.long().flatten()
  • rollout_data holds one mini-batch of the collected samples, including the observations, actions, advantages, old log-probabilities, and so on (the sketch below lists the full set of fields).
  • For a discrete action space, the actions are cast to long and flattened.
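For orientation, each mini-batch yielded by rollout_buffer.get() is a RolloutBufferSamples named tuple. One quick way to inspect its fields after a short run (a sketch; the exact field names and order may vary slightly between SB3 versions):

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", n_steps=64, batch_size=32)
model.learn(total_timesteps=64)  # one rollout of 64 steps, then one train() call

# Take one mini-batch from the filled rollout buffer and print its field names,
# e.g. observations, actions, old_values, old_log_prob, advantages, returns.
batch = next(iter(model.rollout_buffer.get(batch_size=32)))
print(batch._fields)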
(3) Compute the policy loss
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
ratio = th.exp(log_prob - rollout_data.old_log_prob)
  • self.policy.evaluate_actions(): evaluates the mini-batch under the current policy and returns the state values (values), the log-probabilities of the stored actions (log_prob), and the entropy of the action distribution (entropy).
  • ratio = th.exp(log_prob - rollout_data.old_log_prob): the probability ratio between the new and old policies, which measures how much the policy has changed on these samples (see the formula below).
  • How log_prob is computed for discrete vs. continuous action spaces is covered in the companion post on action sampling (强化学习之离散动作采样 vs 连续动作采样).
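In math, the ratio is

r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)} = \exp\big( \log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\mathrm{old}}}(a_t \mid s_t) \big),

where old_log_prob was recorded when the sample was collected. Before the first gradient step of the first epoch the two policies are identical, so the ratio is exactly 1.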

PPO applies the clipping technique; the clipped surrogate loss is computed as:

policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
  • policy_loss_1: the standard (unclipped) policy-gradient surrogate.
  • policy_loss_2: the clipped surrogate, which keeps the update from moving the policy too far.
  • policy_loss = -th.min(policy_loss_1, policy_loss_2).mean(): take the element-wise minimum of the two (the more pessimistic estimate) and negate it for gradient descent; this is what stabilizes the update (a toy example follows below).
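A toy numerical example (hand-picked values, clip_range = 0.2) shows how the element-wise minimum behaves on each side of the clipping boundary:

import torch as th

ratio = th.tensor([1.5, 1.5, 0.5, 0.5])         # new/old probability ratios
advantages = th.tensor([1.0, -1.0, 1.0, -1.0])  # advantage estimates
clip_range = 0.2

policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
per_sample_loss = -th.min(policy_loss_1, policy_loss_2)
print(per_sample_loss)  # tensor([-1.2000,  1.5000, -0.5000,  0.8000])

# Sample 0 (ratio > 1 + clip_range, A > 0): the clipped term wins, so the objective is
#   capped and there is no incentive to push the ratio even higher.
# Sample 1 (ratio > 1 + clip_range, A < 0): the unclipped, more pessimistic term wins,
#   so the policy is still pushed back towards the old one.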
(4) Compute the value-function loss
value_loss = F.mse_loss(rollout_data.returns, values_pred)
  • The value network is trained with a mean-squared error (MSE) loss between the predicted values values_pred and the targets rollout_data.returns (the formulas below also cover the optional value clipping used in the full code).
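Written out, and including the optional value clipping that appears in the appendix code when clip_range_vf is set:

L^{VF}(\theta) = \mathbb{E}_t\big[ (R_t - V_\theta(s_t))^2 \big], \qquad
V^{\mathrm{clip}}_t = V_{\mathrm{old}}(s_t) + \mathrm{clip}\big( V_\theta(s_t) - V_{\mathrm{old}}(s_t),\, -\epsilon_{vf},\, +\epsilon_{vf} \big)

When value clipping is enabled, the MSE is taken against V^{clip}_t instead of V_\theta(s_t).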
(5) Compute the entropy loss to encourage exploration
entropy_loss = -th.mean(entropy)
  • The entropy term rewards more stochastic policies, encouraging exploration and preventing premature convergence.
(6) Combine the losses and take a gradient step
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
self.policy.optimizer.zero_grad()
loss.backward()
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
  • Total loss = policy loss + ent_coef × entropy loss + vf_coef × value loss.
  • zero_grad(): clear the previously accumulated gradients.
  • backward(): backpropagate the total loss to compute new gradients.
  • clip_grad_norm_(): clip the global gradient norm to prevent exploding gradients.
  • optimizer.step(): apply the parameter update.
(7) Monitor the approximate KL divergence and stop early if necessary
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
    continue_training = False
    print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
  • An approximate KL divergence between the old and the updated policy is computed for every mini-batch; if it becomes too large, training on the current rollout stops early so that the policy does not drift too far from the one that collected the data.
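The estimator used here is the low-variance approximation from Schulman's "Approximating KL Divergence" note (linked in the appendix code comments), evaluated on samples drawn under the old policy:

\mathrm{KL}\big[ \pi_{\theta_{\mathrm{old}}} \,\|\, \pi_\theta \big] \approx \mathbb{E}_t\big[ (r_t(\theta) - 1) - \log r_t(\theta) \big], \qquad r_t(\theta) = \exp\big(\log \pi_\theta - \log \pi_{\theta_{\mathrm{old}}}\big)

When target_kl is set, training on the current rollout stops as soon as this estimate exceeds 1.5 × target_kl.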

2.3 Code Walkthrough: Rollout Collection

In OnPolicyAlgorithm.collect_rollouts(), PPO interacts with the environment to collect training data.

obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor)
new_obs, rewards, dones, infos = env.step(actions.cpu().numpy())
  • self.policy(obs_tensor): runs the policy to obtain the actions, value estimates, and log-probabilities.
  • env.step(): executes the actions in the environment and returns the new observations, rewards, done flags, and info dicts (see the vectorized-environment sketch below).
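Note that collect_rollouts() steps a vectorized environment (VecEnv), so observations, rewards, and done flags are batched over env.num_envs. A minimal sketch of building one (the environment id and n_envs are arbitrary):

from stable_baselines3.common.env_util import make_vec_env

# Four copies of the environment stepped in lockstep; outputs are batched per env.
vec_env = make_vec_env("CartPole-v1", n_envs=4)
obs = vec_env.reset()
print(obs.shape)  # (4, 4): (n_envs, obs_dim) for CartPole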

The transitions are stored in the rollout_buffer; at the end of each rollout, the GAE advantages and the returns are computed by the call below.

Advantages and returns are introduced in the companion post on the advantage function (强化学习之Advantage优势函数).

rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
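compute_returns_and_advantage() implements GAE(λ) in its recursive form. With value estimates V(s_t), discount γ, and smoothing factor λ (matching SB3's RolloutBuffer, where the returns are the advantages plus the value estimates):

\delta_t = r_t + \gamma\, V(s_{t+1})\,(1 - d_{t+1}) - V(s_t)

\hat{A}_t = \delta_t + \gamma \lambda\, (1 - d_{t+1})\, \hat{A}_{t+1}

R_t = \hat{A}_t + V(s_t)

where d_{t+1} indicates whether the episode ended after step t.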

3. Summary

PPO in Stable-Baselines3 stabilizes policy updates through the clipped objective and uses approximate-KL early stopping to avoid policy collapse.

The core ideas behind PPO:

  • Bound the size of each policy update to improve training stability.
  • Use GAE-Lambda advantage estimation to reduce variance and improve sample efficiency.
  • Add an entropy bonus to encourage exploration and prevent premature convergence.
  • Clip gradients to avoid exploding gradients and further stabilize training.

In practice, PPO handles large-scale tasks with either continuous or discrete action spaces and is widely applied to robot control, game AI, financial trading, and similar domains.

4. Appendix

The main PPO-related code in Stable-Baselines3:

# https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
class PPO(OnPolicyAlgorithm):
    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

class OnPolicyAlgorithm(BaseAlgorithm):
    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions

            if isinstance(self.action_space, spaces.Box):
                if self.policy.squash_output:
                    # Unscale the actions to match env bounds
                    # if they were previously squashed (scaled in [-1, 1])
                    clipped_actions = self.policy.unscale_action(clipped_actions)
                else:
                    # Otherwise, clip the actions to avoid out of bound error
                    # as we are sampling from an unbounded Gaussian distribution
                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if not callback.on_step():
                return False

            self._update_info_buffer(infos, dones)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstrapping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(
                self._last_obs,  # type: ignore[arg-type]
                actions,
                rewards,
                self._last_episode_starts,  # type: ignore[arg-type]
                values,
                log_probs,
            )
            self._last_obs = new_obs  # type: ignore[assignment]
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.update_locals(locals())

        callback.on_rollout_end()

        return True

    def learn(
        self: SelfOnPolicyAlgorithm,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "OnPolicyAlgorithm",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfOnPolicyAlgorithm:
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            callback,
            reset_num_timesteps,
            tb_log_name,
            progress_bar,
        )

        callback.on_training_start(locals(), globals())

        assert self.env is not None

        while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

            if not continue_training:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                assert self.ep_info_buffer is not None
                self.dump_logs(iteration)

            self.train()

        callback.on_training_end()

        return self