Learning stable-baselines3: Callbacks

A callback is a set of functions that are called at given stages of the training procedure. You can use callbacks to access the internal state of the RL model during training, which allows monitoring, auto-saving, model manipulation, progress bars, and more.

1. Custom callback

To define a custom callback, create a class that derives from BaseCallback. This gives you access to events (_on_training_start, _on_step) and useful variables (such as self.model, the RL model).

from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    A custom callback that derives from ``BaseCallback``.

    :param verbose: (int) Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    def __init__(self, verbose=0):
        super(CustomCallback, self).__init__(verbose)
        # Those variables will be accessible in the callback
        # (they are defined in the base class)
        # The RL model
        # self.model = None  # type: BaseAlgorithm
        # An alias for self.model.get_env(), the environment used for training
        # self.training_env = None  # type: Union[gym.Env, VecEnv, None]
        # Number of times the callback was called
        # self.n_calls = 0  # type: int
        # self.num_timesteps = 0  # type: int
        # local and global variables
        # self.locals = None  # type: Dict[str, Any]
        # self.globals = None  # type: Dict[str, Any]
        # The logger object, used to report things in the terminal
        # self.logger = None  # stable_baselines3.common.logger
        # # Sometimes, for event callback, it is useful
        # # to have access to the parent object
        # self.parent = None  # type: Optional[BaseCallback]

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interaction
        using the current policy.
        This event is triggered before collecting new samples.
        """
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: (bool) If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        pass

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass
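
The callback above is only a skeleton. As a minimal usage sketch (the algorithm and environment here are arbitrary choices, not part of the original example), an instance of the custom callback is simply passed to learn():

from stable_baselines3 import PPO

# Pass an instance of the custom callback to learn()
# (PPO and Pendulum-v0 are only placeholder choices here)
model = PPO('MlpPolicy', 'Pendulum-v0', verbose=0)
model.learn(total_timesteps=1000, callback=CustomCallback(verbose=1))
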
2. Event Callback

Stable Baselines provides a second type of BaseCallback, named EventCallback, that is meant to trigger events. When an event is triggered, a child callback is then called.

As an example, EvalCallback is an EventCallback that will trigger its child callback when there is a new best model. A child callback is, for instance, StopTrainingOnRewardThreshold, which stops the training if the mean reward achieved by the RL model is above a threshold.

class EventCallback(BaseCallback):
    """
    Base class for triggering callback on event.

    :param callback: (Optional[BaseCallback]) Callback that will be called
        when an event is triggered.
    :param verbose: (int)
    """
    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super(EventCallback, self).__init__(verbose=verbose)
        self.callback = callback
        # Give access to the parent
        if callback is not None:
            self.callback.parent = self
    ...

    def _on_event(self) -> bool:
        if self.callback is not None:
            return self.callback()
        return True
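
For illustration, here is a minimal sketch of a concrete EventCallback subclass (the class name TriggerEveryN is hypothetical) that fires its child callback every fixed number of steps, which is essentially what EveryNTimesteps in section 3.5 does:

from stable_baselines3.common.callbacks import EventCallback


class TriggerEveryN(EventCallback):
    """Hypothetical example: run the child callback every ``n_steps`` calls to env.step()."""
    def __init__(self, n_steps: int, callback, verbose: int = 0):
        super().__init__(callback, verbose=verbose)
        self.n_steps = n_steps

    def _on_step(self) -> bool:
        if self.n_calls % self.n_steps == 0:
            # _on_event() calls the child callback and returns its result
            return self._on_event()
        return True
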
3. Callback Collection

Stable Baselines provides a set of common callbacks for tasks such as saving the model periodically, evaluating it periodically, chaining callbacks, triggering callbacks on events, and stopping training early:

3.1 CheckpointCallback

Callback for saving a model every save_freq calls to env.step(). You must specify a log folder (save_path) and, optionally, a prefix for the checkpoints (rl_model by default).

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback
# Save a checkpoint every 1000 steps
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/',
                                         name_prefix='rl_model')

model = SAC('MlpPolicy', 'Pendulum-v0')
model.learn(2000, callback=checkpoint_callback)

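After training, the ./logs/ folder should contain the saved checkpoints. As a minimal sketch, one of them can be reloaded with SAC.load (the file name below, rl_model_2000_steps.zip, is an assumption; check the folder for the names CheckpointCallback actually wrote):

from stable_baselines3 import SAC

# Assumed checkpoint name; look in ./logs/ for the real file names
model = SAC.load('./logs/rl_model_2000_steps.zip')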

3.2 EvalCallback

Evaluate periodically the performance of an agent, using a separate test environment. It will save the best model if the best_model_save_path folder is specified, and save the evaluation results in a NumPy archive (evaluations.npz) if the log_path folder is specified.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback

# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
                             log_path='./logs/', eval_freq=500,
                             deterministic=True, render=False)

model = SAC('MlpPolicy', 'Pendulum-v0')
model.learn(5000, callback=eval_callback)

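Since log_path is set, the evaluation history can be inspected after training. A minimal sketch, assuming the callback wrote ./logs/evaluations.npz with the usual timesteps/results/ep_lengths arrays:

import numpy as np

data = np.load('./logs/evaluations.npz')
# Mean episodic reward at each evaluation point
print(data['timesteps'])
print(data['results'].mean(axis=1))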

3.3 CallbackList

Class for chaining callbacks; they will be called sequentially. Alternatively, you can pass a list of callbacks directly to the learn() method, and it will be converted automatically to a CallbackList.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs3/')
# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs3/best_model',
                             log_path='./logs3/results', eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])

model = SAC('MlpPolicy', 'Pendulum-v0')
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)


3.4 StopTrainingOnRewardThreshold

This section is best read together with section 2 (Event Callback).

Stop the training once a threshold in episodic reward (mean episode reward over the evaluations) has been reached (i.e., when the model is good enough). It must be used with the EvalCallback and uses the event triggered by a new best model.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)

model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)


3.5 EveryNTimesteps

An Event Callback that will trigger its child callback every n_steps timesteps.

import gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps

# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs4/')
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

model = PPO('MlpPolicy', 'Pendulum-v0', verbose=1)

model.learn(int(2e4), callback=event_callback)


3.6 StopTrainingOnMaxEpisodes

Stop the training upon reaching the maximum number of episodes, regardless of the model’s total_timesteps value. It also presumes that, for multiple environments, the desired behavior is that the agent trains on each env for max_episodes and in total for max_episodes * n_envs episodes.

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes

# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)

model = A2C('MlpPolicy', 'Pendulum-v0', verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)

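To illustrate the multi-environment behavior mentioned above, here is a minimal sketch (using make_vec_env with 4 parallel Pendulum-v0 envs is an assumption; training would then stop after roughly max_episodes * n_envs = 20 episodes in total):

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes

# 4 parallel copies of the environment
vec_env = make_vec_env('Pendulum-v0', n_envs=4)
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)

model = A2C('MlpPolicy', vec_env, verbose=1)
model.learn(int(1e10), callback=callback_max_episodes)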

3.7 StopTrainingOnNoModelImprovement

Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations. The idea is to save time in experiments when you know that the learning curves are reasonably well behaved, so after many evaluations without improvement the learning has probably stabilized. It must be used with the EvalCallback and uses the event triggered after every evaluation.

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement

# Separate evaluation env
eval_env = gym.make("Pendulum-v0")
# Stop training if there is no improvement after more than 3 evaluations
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)

model = SAC("MlpPolicy", "Pendulum-v0", learning_rate=1e-3, verbose=1)
# Almost infinite number of timesteps, but the training will stop early
# as soon as the number of consecutive evaluations without model
# improvement is greater than 3
model.learn(int(1e10), callback=eval_callback)

Note: I could not find this class in the callbacks module of my installation; the stable-baselines3 version I am using probably does not implement it yet.
