[Reinforcement Learning] Hands-On: Implementing PPO on Hopper

1. Environment Setup

My setup: Windows 10, no GPU.

Install mujoco and mujoco-py by following this walkthrough: 【Mujoco】在Win10下的安装 - 知乎 (zhihu.com)

Complete every step carefully, or it is easy for something to go wrong. If you hit an error along the way, check the comment section of that guide; a fix may already be posted there.

Install Stable Baselines3. The official documentation has usage examples: Examples — Stable Baselines3 2.4.0a8 documentation (stable-baselines3.readthedocs.io)

It can be installed directly in the py37 virtual environment created in the tutorial above:

pip install stable-baselines3
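
After installing, a quick sanity check in the py37 environment confirms that the three key packages import correctly and shows which versions you got (a minimal sketch; the printed versions will differ from machine to machine):

import gymnasium as gym
import mujoco_py  # the first import triggers mujoco-py's Cython build, which can take a while
import stable_baselines3

print("stable-baselines3:", stable_baselines3.__version__)
print("gymnasium:", gym.__version__)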

2. Hands-On Code

In PyCharm, set the project interpreter to the py37 virtual environment, then run the following code:

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import mujoco_py

# Create parallel (vectorized) environments
vec_env = make_vec_env("Hopper-v2", n_envs=4)

# Initialize the model
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=250000)
model.save("hopper_ppo")

# Delete the model to demonstrate saving and loading
del model

# Reload the saved model
model = PPO.load("hopper_ppo")

obs = vec_env.reset()

while True:
    # Query the policy for actions
    action, _states = model.predict(obs)

    # Step the environments
    obs, rewards, dones, infos = vec_env.step(action)

    # Render the environments
    vec_env.env_method('render')

    # Print rewards (optional)
    #print(f"Rewards: {rewards}")

    # Check whether any environment finished (SB3 VecEnvs also auto-reset sub-environments internally)
    if dones.any():
        obs = vec_env.reset()
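
If you only want a quantitative check instead of watching the rendering loop above, SB3's built-in evaluation helper works too (a short sketch reusing the loaded model; the number of evaluation episodes is arbitrary):

from stable_baselines3.common.evaluation import evaluate_policy

eval_env = make_vec_env("Hopper-v2", n_envs=1)
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward = {mean_reward:.1f} +/- {std_reward:.1f}")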

If rendering throws an error, you can modify these two functions in mujoco_env.py:

    def render(self, mode="human"):
        viewer = self._get_viewer(mode)

        if mode == "human":
            viewer.render()
        elif mode == "rgb_array":
            viewer.render()
            width, height = viewer.window.width, viewer.window.height
            data = viewer.read_pixels(width, height, depth=False)
            return data[::-1, :, :]  # flip the image (pixels are read bottom-up)

    def _get_viewer(
            self, mode
    ) -> Union["mujoco_py.MjViewer", None]:
        self.viewer = self._viewers.get(mode)

        if self.viewer is None:
            if mode == "human":
                try:
                    self.viewer = mujoco_py.MjViewer(self.sim)
                except Exception as e:
                    raise RuntimeError(f"Failed to create MjViewer: {str(e)}")

            elif mode in {"rgb_array"}:
                try:
                    self.viewer = mujoco_py.MjViewer(self.sim)
                except Exception as e:
                    raise RuntimeError(f"Failed to create MjViewer for offscreen rendering: {str(e)}")
            else:
                raise AttributeError(
                    f"Unknown mode: {mode}, expected modes: {self.metadata['render_modes']}"
                )

            self.viewer_setup()
            self._viewers[mode] = self.viewer

        return self.viewer
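
To find the mujoco_env.py file being patched, printing the module path is the quickest route (a sketch that assumes Hopper-v2 is served by the gymnasium-bundled MuJoCo environments, as in this setup):

import gymnasium.envs.mujoco.mujoco_env as mujoco_env
print(mujoco_env.__file__)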

Running the code above then produces output like the following:

cpu device
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 19       |
|    ep_rew_mean     | 15.8     |
| time/              |          |
|    fps             | 4304     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 8192     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 27.7        |
|    ep_rew_mean          | 28.3        |
| time/                   |             |
|    fps                  | 2357        |
|    iterations           | 2           |
|    time_elapsed         | 6           |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.012585539 |
|    clip_fraction        | 0.189       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.25       |
|    explained_variance   | -0.164      |
|    learning_rate        | 0.0003      |
|    loss                 | 10.9        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.019      |
|    std                  | 0.992       |
|    value_loss           | 19.2        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 39.6        |
|    ep_rew_mean          | 54.5        |
| time/                   |             |
|    fps                  | 2052        |
|    iterations           | 3           |
|    time_elapsed         | 11          |
|    total_timesteps      | 24576       |
| train/                  |             |
|    approx_kl            | 0.015236925 |
|    clip_fraction        | 0.216       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.21       |
|    explained_variance   | 0.401       |
|    learning_rate        | 0.0003      |
|    loss                 | 37.3        |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0269     |
|    std                  | 0.978       |
|    value_loss           | 75.5        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 54.5        |
|    ep_rew_mean          | 89.9        |
| time/                   |             |
|    fps                  | 1928        |
|    iterations           | 4           |
|    time_elapsed         | 16          |
|    total_timesteps      | 32768       |
| train/                  |             |
|    approx_kl            | 0.013618716 |
|    clip_fraction        | 0.166       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.17       |
|    explained_variance   | 0.329       |
|    learning_rate        | 0.0003      |
|    loss                 | 87.4        |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0229     |
|    std                  | 0.966       |
|    value_loss           | 170         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 67.7        |
|    ep_rew_mean          | 132         |
| time/                   |             |
|    fps                  | 1848        |
|    iterations           | 5           |
|    time_elapsed         | 22          |
|    total_timesteps      | 40960       |
| train/                  |             |
|    approx_kl            | 0.010008034 |
|    clip_fraction        | 0.118       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.12       |
|    explained_variance   | 0.184       |
|    learning_rate        | 0.0003      |
|    loss                 | 89.7        |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0172     |
|    std                  | 0.951       |
|    value_loss           | 217         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 79.3        |
|    ep_rew_mean          | 161         |
| time/                   |             |
|    fps                  | 1809        |
|    iterations           | 6           |
|    time_elapsed         | 27          |
|    total_timesteps      | 49152       |
| train/                  |             |
|    approx_kl            | 0.008857504 |
|    clip_fraction        | 0.111       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.08       |
|    explained_variance   | 0.285       |
|    learning_rate        | 0.0003      |
|    loss                 | 114         |
|    n_updates            | 50          |
|    policy_gradient_loss | -0.0159     |
|    std                  | 0.94        |
|    value_loss           | 215         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 91.3        |
|    ep_rew_mean          | 192         |
| time/                   |             |
|    fps                  | 1784        |
|    iterations           | 7           |
|    time_elapsed         | 32          |
|    total_timesteps      | 57344       |
| train/                  |             |
|    approx_kl            | 0.009635251 |
|    clip_fraction        | 0.0983      |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.04       |
|    explained_variance   | 0.489       |
|    learning_rate        | 0.0003      |
|    loss                 | 70.2        |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.014      |
|    std                  | 0.928       |
|    value_loss           | 207         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 100         |
|    ep_rew_mean          | 217         |
| time/                   |             |
|    fps                  | 1763        |
|    iterations           | 8           |
|    time_elapsed         | 37          |
|    total_timesteps      | 65536       |
| train/                  |             |
|    approx_kl            | 0.010597339 |
|    clip_fraction        | 0.109       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4          |
|    explained_variance   | 0.73        |
|    learning_rate        | 0.0003      |
|    loss                 | 85.7        |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0129     |
|    std                  | 0.916       |
|    value_loss           | 135         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 104         |
|    ep_rew_mean          | 231         |
| time/                   |             |
|    fps                  | 1748        |
|    iterations           | 9           |
|    time_elapsed         | 42          |
|    total_timesteps      | 73728       |
| train/                  |             |
|    approx_kl            | 0.010786813 |
|    clip_fraction        | 0.132       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4          |
|    explained_variance   | 0.886       |
|    learning_rate        | 0.0003      |
|    loss                 | 7.66        |
|    n_updates            | 80          |
|    policy_gradient_loss | -0.0107     |
|    std                  | 0.921       |
|    value_loss           | 73.2        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 109         |
|    ep_rew_mean          | 249         |
| time/                   |             |
|    fps                  | 1738        |
|    iterations           | 10          |
|    time_elapsed         | 47          |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.010976644 |
|    clip_fraction        | 0.116       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.98       |
|    explained_variance   | 0.916       |
|    learning_rate        | 0.0003      |
|    loss                 | 58.3        |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.0108     |
|    std                  | 0.908       |
|    value_loss           | 75.1        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 111         |
|    ep_rew_mean          | 258         |
| time/                   |             |
|    fps                  | 1731        |
|    iterations           | 11          |
|    time_elapsed         | 52          |
|    total_timesteps      | 90112       |
| train/                  |             |
|    approx_kl            | 0.011209324 |
|    clip_fraction        | 0.132       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.93       |
|    explained_variance   | 0.926       |
|    learning_rate        | 0.0003      |
|    loss                 | 21.2        |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.013      |
|    std                  | 0.893       |
|    value_loss           | 72          |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 117         |
|    ep_rew_mean          | 278         |
| time/                   |             |
|    fps                  | 1725        |
|    iterations           | 12          |
|    time_elapsed         | 56          |
|    total_timesteps      | 98304       |
| train/                  |             |
|    approx_kl            | 0.009375557 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.9        |
|    explained_variance   | 0.922       |
|    learning_rate        | 0.0003      |
|    loss                 | 25.2        |
|    n_updates            | 110         |
|    policy_gradient_loss | -0.0127     |
|    std                  | 0.887       |
|    value_loss           | 85.2        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 124         |
|    ep_rew_mean          | 298         |
| time/                   |             |
|    fps                  | 1719        |
|    iterations           | 13          |
|    time_elapsed         | 61          |
|    total_timesteps      | 106496      |
| train/                  |             |
|    approx_kl            | 0.010624651 |
|    clip_fraction        | 0.127       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.85       |
|    explained_variance   | 0.928       |
|    learning_rate        | 0.0003      |
|    loss                 | 34.1        |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.0105     |
|    std                  | 0.868       |
|    value_loss           | 72.7        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 126         |
|    ep_rew_mean          | 305         |
| time/                   |             |
|    fps                  | 1714        |
|    iterations           | 14          |
|    time_elapsed         | 66          |
|    total_timesteps      | 114688      |
| train/                  |             |
|    approx_kl            | 0.009903484 |
|    clip_fraction        | 0.106       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.82       |
|    explained_variance   | 0.941       |
|    learning_rate        | 0.0003      |
|    loss                 | 23.6        |
|    n_updates            | 130         |
|    policy_gradient_loss | -0.00767    |
|    std                  | 0.866       |
|    value_loss           | 66.4        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 131         |
|    ep_rew_mean          | 325         |
| time/                   |             |
|    fps                  | 1711        |
|    iterations           | 15          |
|    time_elapsed         | 71          |
|    total_timesteps      | 122880      |
| train/                  |             |
|    approx_kl            | 0.011343157 |
|    clip_fraction        | 0.132       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.8        |
|    explained_variance   | 0.941       |
|    learning_rate        | 0.0003      |
|    loss                 | 19.8        |
|    n_updates            | 140         |
|    policy_gradient_loss | -0.00909    |
|    std                  | 0.854       |
|    value_loss           | 77.6        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 139         |
|    ep_rew_mean          | 354         |
| time/                   |             |
|    fps                  | 1707        |
|    iterations           | 16          |
|    time_elapsed         | 76          |
|    total_timesteps      | 131072      |
| train/                  |             |
|    approx_kl            | 0.010334579 |
|    clip_fraction        | 0.106       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.75       |
|    explained_variance   | 0.927       |
|    learning_rate        | 0.0003      |
|    loss                 | 43          |
|    n_updates            | 150         |
|    policy_gradient_loss | -0.00583    |
|    std                  | 0.839       |
|    value_loss           | 97          |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 141         |
|    ep_rew_mean          | 366         |
| time/                   |             |
|    fps                  | 1704        |
|    iterations           | 17          |
|    time_elapsed         | 81          |
|    total_timesteps      | 139264      |
| train/                  |             |
|    approx_kl            | 0.008545404 |
|    clip_fraction        | 0.112       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.71       |
|    explained_variance   | 0.893       |
|    learning_rate        | 0.0003      |
|    loss                 | 99.5        |
|    n_updates            | 160         |
|    policy_gradient_loss | -0.00854    |
|    std                  | 0.83        |
|    value_loss           | 136         |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 143          |
|    ep_rew_mean          | 373          |
| time/                   |              |
|    fps                  | 1701         |
|    iterations           | 18           |
|    time_elapsed         | 86           |
|    total_timesteps      | 147456       |
| train/                  |              |
|    approx_kl            | 0.0077442103 |
|    clip_fraction        | 0.0886       |
|    clip_range           | 0.2          |
|    entropy_loss         | -3.68        |
|    explained_variance   | 0.936        |
|    learning_rate        | 0.0003       |
|    loss                 | 28.8         |
|    n_updates            | 170          |
|    policy_gradient_loss | -0.00499     |
|    std                  | 0.821        |
|    value_loss           | 95.3         |
------------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 146         |
|    ep_rew_mean          | 389         |
| time/                   |             |
|    fps                  | 1699        |
|    iterations           | 19          |
|    time_elapsed         | 91          |
|    total_timesteps      | 155648      |
| train/                  |             |
|    approx_kl            | 0.011910671 |
|    clip_fraction        | 0.128       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.63       |
|    explained_variance   | 0.969       |
|    learning_rate        | 0.0003      |
|    loss                 | 20.6        |
|    n_updates            | 180         |
|    policy_gradient_loss | -0.00725    |
|    std                  | 0.809       |
|    value_loss           | 54.4        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 157         |
|    ep_rew_mean          | 425         |
| time/                   |             |
|    fps                  | 1697        |
|    iterations           | 20          |
|    time_elapsed         | 96          |
|    total_timesteps      | 163840      |
| train/                  |             |
|    approx_kl            | 0.011053506 |
|    clip_fraction        | 0.127       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.59       |
|    explained_variance   | 0.912       |
|    learning_rate        | 0.0003      |
|    loss                 | 91.5        |
|    n_updates            | 190         |
|    policy_gradient_loss | -0.00768    |
|    std                  | 0.795       |
|    value_loss           | 150         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 171         |
|    ep_rew_mean          | 480         |
| time/                   |             |
|    fps                  | 1695        |
|    iterations           | 21          |
|    time_elapsed         | 101         |
|    total_timesteps      | 172032      |
| train/                  |             |
|    approx_kl            | 0.010869894 |
|    clip_fraction        | 0.126       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.55       |
|    explained_variance   | 0.932       |
|    learning_rate        | 0.0003      |
|    loss                 | 28.1        |
|    n_updates            | 200         |
|    policy_gradient_loss | -0.00736    |
|    std                  | 0.789       |
|    value_loss           | 127         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 182         |
|    ep_rew_mean          | 520         |
| time/                   |             |
|    fps                  | 1690        |
|    iterations           | 22          |
|    time_elapsed         | 106         |
|    total_timesteps      | 180224      |
| train/                  |             |
|    approx_kl            | 0.010209182 |
|    clip_fraction        | 0.124       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.51       |
|    explained_variance   | 0.967       |
|    learning_rate        | 0.0003      |
|    loss                 | 23.5        |
|    n_updates            | 210         |
|    policy_gradient_loss | -0.00598    |
|    std                  | 0.776       |
|    value_loss           | 64.3        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 193         |
|    ep_rew_mean          | 558         |
| time/                   |             |
|    fps                  | 1684        |
|    iterations           | 23          |
|    time_elapsed         | 111         |
|    total_timesteps      | 188416      |
| train/                  |             |
|    approx_kl            | 0.009160686 |
|    clip_fraction        | 0.116       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.46       |
|    explained_variance   | 0.943       |
|    learning_rate        | 0.0003      |
|    loss                 | 35          |
|    n_updates            | 220         |
|    policy_gradient_loss | -0.00495    |
|    std                  | 0.766       |
|    value_loss           | 124         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 208         |
|    ep_rew_mean          | 603         |
| time/                   |             |
|    fps                  | 1679        |
|    iterations           | 24          |
|    time_elapsed         | 117         |
|    total_timesteps      | 196608      |
| train/                  |             |
|    approx_kl            | 0.010316408 |
|    clip_fraction        | 0.139       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.42       |
|    explained_variance   | 0.944       |
|    learning_rate        | 0.0003      |
|    loss                 | 17.1        |
|    n_updates            | 230         |
|    policy_gradient_loss | -0.00458    |
|    std                  | 0.757       |
|    value_loss           | 141         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 225         |
|    ep_rew_mean          | 663         |
| time/                   |             |
|    fps                  | 1675        |
|    iterations           | 25          |
|    time_elapsed         | 122         |
|    total_timesteps      | 204800      |
| train/                  |             |
|    approx_kl            | 0.010222745 |
|    clip_fraction        | 0.125       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.4        |
|    explained_variance   | 0.95        |
|    learning_rate        | 0.0003      |
|    loss                 | 14.6        |
|    n_updates            | 240         |
|    policy_gradient_loss | -0.00486    |
|    std                  | 0.75        |
|    value_loss           | 128         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 238         |
|    ep_rew_mean          | 713         |
| time/                   |             |
|    fps                  | 1670        |
|    iterations           | 26          |
|    time_elapsed         | 127         |
|    total_timesteps      | 212992      |
| train/                  |             |
|    approx_kl            | 0.010443018 |
|    clip_fraction        | 0.133       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.36       |
|    explained_variance   | 0.982       |
|    learning_rate        | 0.0003      |
|    loss                 | 13.2        |
|    n_updates            | 250         |
|    policy_gradient_loss | -0.00538    |
|    std                  | 0.741       |
|    value_loss           | 44.8        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 252         |
|    ep_rew_mean          | 758         |
| time/                   |             |
|    fps                  | 1665        |
|    iterations           | 27          |
|    time_elapsed         | 132         |
|    total_timesteps      | 221184      |
| train/                  |             |
|    approx_kl            | 0.010160094 |
|    clip_fraction        | 0.116       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.33       |
|    explained_variance   | 0.9         |
|    learning_rate        | 0.0003      |
|    loss                 | 315         |
|    n_updates            | 260         |
|    policy_gradient_loss | -0.00452    |
|    std                  | 0.734       |
|    value_loss           | 297         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 254         |
|    ep_rew_mean          | 773         |
| time/                   |             |
|    fps                  | 1660        |
|    iterations           | 28          |
|    time_elapsed         | 138         |
|    total_timesteps      | 229376      |
| train/                  |             |
|    approx_kl            | 0.011000248 |
|    clip_fraction        | 0.143       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.29       |
|    explained_variance   | 0.966       |
|    learning_rate        | 0.0003      |
|    loss                 | 8.94        |
|    n_updates            | 270         |
|    policy_gradient_loss | -0.00638    |
|    std                  | 0.725       |
|    value_loss           | 98.7        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 254         |
|    ep_rew_mean          | 785         |
| time/                   |             |
|    fps                  | 1653        |
|    iterations           | 29          |
|    time_elapsed         | 143         |
|    total_timesteps      | 237568      |
| train/                  |             |
|    approx_kl            | 0.010932952 |
|    clip_fraction        | 0.132       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.25       |
|    explained_variance   | 0.986       |
|    learning_rate        | 0.0003      |
|    loss                 | 21.1        |
|    n_updates            | 280         |
|    policy_gradient_loss | -0.00706    |
|    std                  | 0.715       |
|    value_loss           | 39.4        |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 260         |
|    ep_rew_mean          | 813         |
| time/                   |             |
|    fps                  | 1649        |
|    iterations           | 30          |
|    time_elapsed         | 149         |
|    total_timesteps      | 245760      |
| train/                  |             |
|    approx_kl            | 0.011197671 |
|    clip_fraction        | 0.128       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.23       |
|    explained_variance   | 0.969       |
|    learning_rate        | 0.0003      |
|    loss                 | 13.4        |
|    n_updates            | 290         |
|    policy_gradient_loss | -0.00404    |
|    std                  | 0.713       |
|    value_loss           | 106         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 265         |
|    ep_rew_mean          | 832         |
| time/                   |             |
|    fps                  | 1644        |
|    iterations           | 31          |
|    time_elapsed         | 154         |
|    total_timesteps      | 253952      |
| train/                  |             |
|    approx_kl            | 0.012921608 |
|    clip_fraction        | 0.128       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.19       |
|    explained_variance   | 0.955       |
|    learning_rate        | 0.0003      |
|    loss                 | 258         |
|    n_updates            | 300         |
|    policy_gradient_loss | -0.00482    |
|    std                  | 0.699       |
|    value_loss           | 164         |
-----------------------------------------
Creating window glfw

最后会有四个渲染的hopper的动图。
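
The ep_rew_mean column in the log keeps climbing, which is the learning curve we care about. To keep those statistics instead of only reading them off the console, the SB3 logger can be redirected to CSV before training and the curve plotted afterwards (a sketch; the log folder name is arbitrary, and pandas/matplotlib ship as SB3 dependencies):

from stable_baselines3.common.logger import configure
import pandas as pd
import matplotlib.pyplot as plt

# Write training statistics to stdout and to ./hopper_ppo_logs/progress.csv
new_logger = configure("./hopper_ppo_logs", ["stdout", "csv"])
model.set_logger(new_logger)  # call this before model.learn(...)
model.learn(total_timesteps=250000)

# Plot the learning curve from the recorded statistics
df = pd.read_csv("./hopper_ppo_logs/progress.csv")
plt.plot(df["time/total_timesteps"], df["rollout/ep_rew_mean"])
plt.xlabel("total_timesteps")
plt.ylabel("ep_rew_mean")
plt.show()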

3. PPO Theory

PPO is short for proximal policy optimization.

PPO was proposed mainly because, for vanilla policy gradient methods on continuous action spaces, the learning rate is hard to choose. Too small a learning rate makes deep RL converge poorly and training may never finish; too large a value causes a mismatch between the new and old policies at each iteration, producing large learning fluctuations or local oscillation. On top of that, because policy gradient is on-policy, samples collected under the previous policy cannot be reused once the policy is updated, so every iteration requires fresh sampling.

Trust Region Policy Optimization (TRPO) addresses some of this: importance sampling and a conjugate-gradient solver improve sample efficiency and training speed, but the second-order approximation it relies on is computationally expensive, complicated to implement, and poorly compatible with many architectures. PPO keeps part of the advantages of both policy gradient and TRPO: it alternates between sampling data and optimizing a surrogate objective with stochastic gradient ascent. Whereas the standard policy gradient method performs one gradient update per data sample, PPO introduces a new objective function that allows minibatch updates over multiple epochs.

Core idea:

At each update, the algorithm compares the probability \pi_{\theta}(a_{t}|s_{t}) that the current policy assigns to the action taken in state s_{t} at time t with the probability \pi_{\theta_{old}}(a_{t}|s_{t}) under the previous policy, and uses the ratio r_{t} to control how far the new policy is allowed to move:

r_{t}(\theta)=\frac{\pi_{\theta}(a_{t}|s_{t})}{\pi_{\theta_{old}}(a_{t}|s_{t})}

If the new and old policies differ noticeably and the advantage is large, the update step is increased accordingly; the closer r_{t} is to 1, the smaller the difference between the new and old policies.

The advantage function measures how much better action a is than average in state s. In the paper, the advantage is computed with GAE (generalized advantage estimation):
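
Filling in the formula the post refers to, the GAE estimator from the PPO paper is

\hat{A}_{t}=\sum_{l=0}^{\infty}(\gamma\lambda)^{l}\delta_{t+l}, \qquad \delta_{t}=R_{t}+\gamma V(s_{t+1})-V(s_{t})

where R_{t} is the reward at step t (capitalized here to avoid clashing with the ratio r_{t}(\theta)), \gamma is the discount factor, and \lambda trades bias against variance; these correspond to the gamma and gae_lambda arguments of PPO in Stable Baselines3.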

Depending on how the actor network is updated, PPO comes in two refinements: PPO-Penalty, which uses an adaptive KL-divergence penalty, and PPO-Clip, which uses a clipped surrogate objective.

If the advantage is positive, the objective encourages increasing the ratio r_{t}, but once r_{t}>1+\varepsilon no extra incentive is given; if the advantage is negative, the objective encourages decreasing r_{t}, but once r_{t}<1-\varepsilon no extra incentive is given. This keeps the difference between the new and old policies within a reasonable range. PPO is essentially built on the actor-critic framework.
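
Written out, this clipping behaviour is exactly the clipped surrogate objective from the PPO paper:

L^{CLIP}(\theta)=\hat{\mathbb{E}}_{t}\left[\min\left(r_{t}(\theta)\hat{A}_{t},\ \mathrm{clip}(r_{t}(\theta),1-\varepsilon,1+\varepsilon)\hat{A}_{t}\right)\right]

where \varepsilon is the clip range (0.2 by default in Stable Baselines3, which is the clip_range value visible in the training log above).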

PPO consists of two parts, an actor and a critic. The critic is updated much like in other actor-critic methods, typically by minimizing a TD error (temporal-difference error). For the actor, PPO can choose between the L^{KLPEN} and L^{CLIP} objectives, whichever is more stable and better suited to the experimental environment; OpenAI's experiments showed that PPO-Clip has better data efficiency and practicality than PPO-Penalty.
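
For reference, the KL-penalty counterpart from the PPO paper (with the coefficient \beta adapted between updates in the adaptive variant) is

L^{KLPEN}(\theta)=\hat{\mathbb{E}}_{t}\left[r_{t}(\theta)\hat{A}_{t}-\beta\,\mathrm{KL}\!\left[\pi_{\theta_{old}}(\cdot|s_{t}),\ \pi_{\theta}(\cdot|s_{t})\right]\right]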

4. Code and Results Analysis

The PPO implementation in Stable Baselines3 is essentially the following:

import warnings
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
import matplotlib.pyplot as plt
SelfPPO = TypeVar("SelfPPO", bound="PPO")


class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version)

    Paper: https://arxiv.org/abs/1707.06347
    Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
    Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": ActorCriticPolicy,
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert buffer_size > 1 or (
                not normalize_advantage
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        super()._setup_model()

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self: SelfPPO,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "PPO",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfPPO:
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
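
All of the hyperparameters documented in the constructor above can be overridden when creating the model. The values below are only an illustrative sketch (not tuned for Hopper), showing how the knobs that appear in the training log map onto constructor arguments:

model = PPO(
    "MlpPolicy",
    vec_env,
    n_steps=2048,      # rollout length per environment (n_steps * n_envs = 8192 per iteration, as in the log)
    batch_size=64,     # minibatch size for each gradient step
    n_epochs=10,       # passes over the rollout buffer per update (matches n_updates increasing by 10)
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.1,    # smaller epsilon than the 0.2 default
    target_kl=0.03,    # train() stops an update early once approx_kl exceeds 1.5 * target_kl
    verbose=1,
)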

References:

【深度强化学习】(6) PPO 模型解析,附Pytorch完整代码_ppo模型-CSDN博客
