stable-baselines3学习之Tensorboard

stable-baselines3学习之Tensorboard系列

1.基本用法

要使用stable-baselines3的 Tensorboard,您只需将日志文件夹的位置传递给 RL 的agent:

from stable_baselines3 import A2C

model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)

您还可以在训练时定义自定义日志名称(默认为算法名称)

from stable_baselines3 import A2C

model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)

调用 learn 函数后,您可以使用以下 bash 命令在训练期间或之后监控 RL agent:

tensorboard --logdir ./a2c_cartpole_tensorboard/

注:要在该项目文件路径下运行这条命令

比如:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WpecFpF9-1647931462737)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322120347758.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M7dMwEEI-1647931462738)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322120409523.png)]

2.Logging More Values

使用callback可以容易的记录更多日志用Tensorboard,这里有一个简单的例子去记录额外的tensor和任意的scalar值:

import numpy as np

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="./tmp/sac0/", verbose=1)


class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super(TensorboardCallback, self).__init__(verbose)

    def _on_step(self) -> bool:
        # Log scalar value (here a random variable)
        value = np.random.random()
        self.logger.record('random_value', value)
        return True


model.learn(50000, callback=TensorboardCallback())
tensorboard --logdir ./tmp/sac0/

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pIq7VA1T-1647931462738)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322124527536.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-utw3W67U-1647931462739)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322124553448.png)]

3.Logging Images

TensorBoard 支持定期记录图像数据,这有助于在训练期间的各个阶段评估agent。

以下是如何定期将图像渲染到 TensorBoard 的示例:

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Image

model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="./tmp/sac1/", verbose=1)


class ImageRecorderCallback(BaseCallback):
    def __init__(self, verbose=0):
        super(ImageRecorderCallback, self).__init__(verbose)

    def _on_step(self):
        image = self.training_env.render(mode="rgb_array")
        # "HWC" specify the dataformat of the image, here channel last
        # (H for height, W for width, C for channel)
        # See https://pytorch.org/docs/stable/tensorboard.html
        # for supported formats
        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
        return True


model.learn(50000, callback=ImageRecorderCallback())

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ocQxw7ue-1647931462739)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322130204745.png)]

tensorboard --logdir ./tmp/sac1/

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eAsoEg4Q-1647931462740)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322131447492.png)]

4.Logging Figures/Plots

TensorBoard 支持定期记录使用 matplotlib 创建的图形/绘图,这有助于在训练期间评估各个阶段的agent。

以下是如何在 TensorBoard 中定期存储绘图的示例:

import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure

model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="./tmp/sac2/", verbose=1)


class FigureRecorderCallback(BaseCallback):
    def __init__(self, verbose=0):
        super(FigureRecorderCallback, self).__init__(verbose)

    def _on_step(self):
        # Plot values (here a random variable)
        figure = plt.figure()
        figure.add_subplot().plot(np.random.random(3))
        # Close the figure after logging it
        self.logger.record("trajectory/figure", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))
        plt.close()
        return True


model.learn(50000, callback=FigureRecorderCallback())
tensorboard --logdir ./tmp/sac1/

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IGphPL0B-1647931462740)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322134756257.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9rxhe5yd-1647931462741)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322134906652.png)]

5.Logging Videos

TensorBoard 支持定期记录视频数据,这有助于在训练期间评估各个阶段的agent。

以下是如何显示一个episode并将生成的视频定期记录到 TensorBoard 的示例:

注:需安装moviepy

from typing import Any, Dict

import gym
import torch as th

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import Video


class VideoRecorderCallback(BaseCallback):
    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded
        :param render_freq: Render the agent's trajectory every eval_freq call of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        if self.n_calls % self._render_freq == 0:
            screens = []

            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                """
                Renders the environment in its current state, recording the screen in the captured `screens` list

                :param _locals: A dictionary containing all local variables of the callback's scope
                :param _globals: A dictionary containing all global variables of the callback's scope
                """
                screen = self._eval_env.render(mode="rgb_array")
                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                screens.append(screen.transpose(2, 0, 1))

            evaluate_policy(
                self.model,
                self._eval_env,
                callback=grab_screens,
                n_eval_episodes=self._n_eval_episodes,
                deterministic=self._deterministic,
            )
            self.logger.record(
                "trajectory/video",
                Video(th.ByteTensor([screens]), fps=40),
                exclude=("stdout", "log", "json", "csv"),
            )
        return True


model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="./tmp/runs/", verbose=1)
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YficDhFG-1647931462741)(C:\Users\admin\AppData\Roaming\Typora\typora-user-images\image-20220322142640752.png)]

tensorboard --logdir ./tmp/runs/

在这里插入图片描述

  • 2
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: 好的,下面是使用stable-baselines3搭建ppo算法的步骤: 1. 安装stable-baselines3:可以使用pip命令进行安装:`pip install stable-baselines3` 2. 导入相关库:`import gym`,`from stable_baselines3 import PPO` 3. 创建环境:`env = gym.make('环境名称')` 4. 定义PPO模型:`model = PPO('MlpPolicy', env, verbose=1)` 其中,'MlpPolicy'是指使用多层感知器作为策略网络,也可以使用'CnnPolicy'使用卷积神经网络作为策略网络。 5. 训练模型:`model.learn(total_timesteps=10000)` 其中,total_timesteps是训练模型的总步数。 6. 保存模型:`model.save('模型名称')` 可以将模型保存在本地,以便之后使用。 7. 加载模型:`model = PPO.load('模型名称')` 可以从本地加载模型,以便之后进行测试或使用。 8. 测试模型:`obs = env.reset()`,`for i in range(1000):`,` action, _states = model.predict(obs)`,` obs, rewards, dones, info = env.step(action)`,` env.render()` 其中,obs是环境的初始状态,model.predict(obs)是使用模型预测下一步的动作,env.step(action)是执行动作并返回下一步的状态、奖励、是否结束等信息,env.render()是将环境渲染出来以便观察。 以上就是使用stable-baselines3搭建ppo算法的步骤,希望能对你有所帮助。 ### 回答2: Stable Baselines3是一个用于强化学习的Python库,它提供了多种强化学习算法的实现,包括PPO算法。下面是使用Stable Baselines3搭建PPO算法的步骤: 1. 安装Stable Baselines3 首先,需要在Python环境中安装Stable Baselines3库。可以通过pip命令进行安装:`pip install stable-baselines3` 2. 定义环境 在使用PPO算法之前,需要定义一个强化学习环境。这个环境可以是OpenAI Gym中的现有环境,也可以是自定义的环境。确保环境具备与PPO算法兼容的状态和动作空间。 3. 创建PPO模型 使用Stable Baselines3中的`PPO`类创建一个PPO模型对象。需要指定环境和其他参数,例如神经网络结构和学习率等。 ``` from stable_baselines3 import PPO model = PPO("MlpPolicy", env, verbose=1) ``` 4. 训练模型 使用创建的PPO模型对象对环境进行模型训练。可以指定训练的轮数(epochs)和每轮的步数(steps),以及其他训练参数。 ``` model.learn(total_timesteps=10000) ``` 5. 使用模型进行预测 在训练完成后,可以使用训练好的模型对新的状态进行预测。通过调用模型的predict方法,给定当前的状态,模型会输出一个动作。 ``` action = model.predict(observation) ``` 以上就是使用Stable Baselines3搭建PPO算法的基本步骤。根据具体的应用场景,还可以对训练过程和模型进行更多的调优和优化。 ### 回答3: stable-baselines3是一个Python库,可以用于搭建PPO(Proximal Policy Optimization)算法。PPO是一种强化学习算法,用于训练策略(policy)函数,以在强化学习任务中找到最优的策略。 首先,我们需要安装stable-baselines3库。可以通过在命令行中运行`pip install stable-baselines3`来完成安装。 然后,我们通过导入所需的模块来开始构建PPO算法。例如,我们可以导入`PPO`类,并创建一个模型对象。可以在创建模型对象时指定所需的超参数,例如神经网络的结构和学习率。 接下来,我们需要定义我们的环境。stable-baselines3库支持与OpenAI Gym兼容的环境。可以通过导入`gym`模块来创建环境对象,并将其传递给模型对象。 一旦有了模型和环境,我们就可以开始训练了。可以使用模型对象的`learn()`方法来执行训练。该方法需要指定训练的时间步数或迭代次数,以及其他训练相关的超参数。 一般来说,在训练过程中,我们可以选择保存模型的检查点,以便以后使用。stable-baselines3提供了保存和加载模型的功能,可以使用模型对象的`save()`和`load()`方法来完成。 一旦模型训练完成,我们可以使用训练好的策略函数来测试和评估模型的性能。可以使用模型对象的`predict()`方法来获取模型在给定状态下的动作。 总结来说,使用stable-baselines3搭建PPO算法的步骤包括安装库、创建模型对象、定义环境、执行训练和保存模型、使用训练好的模型进行测试和评估。这些步骤可以帮助我们构建一个基于PPO算法的强化学习模型

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值