A callback is a set of functions that are called at given stages of the training procedure. You can use callbacks to access the internal state of the RL model during training, which enables monitoring, auto-saving, model manipulation, progress bars, and more.
1. Custom Callback
To build a custom callback, create a class that derives from BaseCallback. This gives you access to events (_on_training_start, _on_step, ...) and useful variables (such as self.model, the RL model).
from stable_baselines3.common.callbacks import BaseCallback
class CustomCallback(BaseCallback):
"""
A custom callback that derives from ``BaseCallback``.
:param verbose: (int) Verbosity level. 0: no output, 1: info, 2: debug
"""
def __init__(self, verbose=0):
super(CustomCallback, self).__init__(verbose)
# Those variables will be accessible in the callback
# (they are defined in the base class)
# The RL model
# self.model = None # type: BaseAlgorithm
# An alias for self.model.get_env(), the environment used for training
# self.training_env = None # type: Union[gym.Env, VecEnv, None]
# Number of time the callback was called
# self.n_calls = 0 # type: int
# self.num_timesteps = 0 # type: int
# local and global variables
# self.locals = None # type: Dict[str, Any]
# self.globals = None # type: Dict[str, Any]
# The logger object, used to report things in the terminal
# self.logger = None # stable_baselines3.common.logger
# # Sometimes, for event callback, it is useful
# # to have access to the parent object
# self.parent = None # type: Optional[BaseCallback]
def _on_training_start(self) -> None:
"""
This method is called before the first rollout starts.
"""
pass
def _on_rollout_start(self) -> None:
"""
A rollout is the collection of environment interaction
using the current policy.
This event is triggered before collecting new samples.
"""
pass
def _on_step(self) -> bool:
"""
This method will be called by the model after each call to `env.step()`.
For child callback (of an `EventCallback`), this will be called
when the event is triggered.
:return: (bool) If the callback returns False, training is aborted early.
"""
return True
def _on_rollout_end(self) -> None:
"""
This event is triggered before updating the policy.
"""
pass
def _on_training_end(self) -> None:
"""
This event is triggered before exiting the `learn()` method.
"""
pass
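To make the hook order concrete, here is a dependency-free sketch of how learn() drives these methods. MiniTrainer is a hypothetical stand-in for an SB3 algorithm (it is not part of the library), but the real library calls the same five hooks in the same order.

```python
class RecordingCallback:
    """Records which hooks fire, mirroring BaseCallback's interface."""

    def __init__(self):
        self.events = []
        self.n_calls = 0

    def _on_training_start(self):
        self.events.append("training_start")

    def _on_rollout_start(self):
        self.events.append("rollout_start")

    def _on_step(self) -> bool:
        self.n_calls += 1
        self.events.append("step")
        return True  # returning False would abort training early

    def _on_rollout_end(self):
        self.events.append("rollout_end")

    def _on_training_end(self):
        self.events.append("training_end")


class MiniTrainer:
    """Hypothetical trainer: runs 2 rollouts of 3 env.step() calls each."""

    def learn(self, callback, n_rollouts=2, rollout_len=3):
        callback._on_training_start()
        for _ in range(n_rollouts):
            callback._on_rollout_start()
            for _ in range(rollout_len):
                if not callback._on_step():  # env.step() would happen here
                    callback._on_training_end()
                    return
            callback._on_rollout_end()       # policy update would happen here
        callback._on_training_end()


cb = RecordingCallback()
MiniTrainer().learn(cb)
print(cb.n_calls)   # 6: one per simulated env.step()
print(cb.events[0], cb.events[-1])
```

Running this prints the lifecycle boundaries: training_start first, training_end last, with two rollout_start/rollout_end pairs in between.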
2. Event Callback
Stable Baselines3 provides a second type of BaseCallback, named EventCallback, that is meant to trigger events. When an event is triggered, the child callback is called.
As an example, EvalCallback is an EventCallback that will trigger its child callback when there is a new best model. A child callback is, for instance, StopTrainingOnRewardThreshold, which stops the training if the mean reward achieved by the RL model is above a threshold.
from typing import Optional

from stable_baselines3.common.callbacks import BaseCallback

class EventCallback(BaseCallback):
"""
Base class for triggering callback on event.
:param callback: (Optional[BaseCallback]) Callback that will be called
when an event is triggered.
:param verbose: (int)
"""
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
super(EventCallback, self).__init__(verbose=verbose)
self.callback = callback
# Give access to the parent
if callback is not None:
self.callback.parent = self
...
def _on_event(self) -> bool:
if self.callback is not None:
return self.callback()
return True
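The parent/child wiring above can be exercised without SB3. The following sketch mirrors it: _on_event forwards to the child, and the child gets a reference back to its parent. EveryNCallsCallback and ChildCallback are hypothetical names for illustration, not library classes.

```python
class ChildCallback:
    """Child callback invoked when the parent's event fires."""

    def __init__(self):
        self.parent = None   # set by the parent EventCallback
        self.triggered = 0

    def __call__(self) -> bool:
        self.triggered += 1
        return True  # returning False would propagate up and stop training


class EveryNCallsCallback:
    """Hypothetical EventCallback: fires its child every `n` _on_step calls."""

    def __init__(self, callback, n):
        self.callback = callback
        self.callback.parent = self  # give the child access to its parent
        self.n = n
        self.n_calls = 0

    def _on_event(self) -> bool:
        # Same forwarding logic as EventCallback._on_event above
        if self.callback is not None:
            return self.callback()
        return True

    def _on_step(self) -> bool:
        self.n_calls += 1
        if self.n_calls % self.n == 0:
            return self._on_event()
        return True


child = ChildCallback()
event_cb = EveryNCallsCallback(child, n=5)
for _ in range(12):          # simulate 12 env.step() calls
    event_cb._on_step()
print(child.triggered)       # fired at steps 5 and 10
```

Over 12 simulated steps with n=5, the event fires twice (at steps 5 and 10), and child.parent points back at the event callback, just as EventCallback.__init__ arranges in the snippet above.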
3. Callback Collection
Stable Baselines provides you with a set of common callbacks for:
- saving the model periodically (CheckpointCallback)
- evaluating the model periodically and saving the best one (EvalCallback)
- chaining callbacks (CallbackList)
- triggering callbacks on events (EventCallback, EveryNTimesteps)
- stopping the training early based on a reward threshold (StopTrainingOnRewardThreshold)
3.1 CheckpointCallback
Callback for saving a model every save_freq calls to env.step(). You must specify a log folder (save_path) and, optionally, a prefix for the checkpoints (rl_model by default).
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback
# Save a checkpoint every 1000 steps
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/',
name_prefix='rl_model')
model = SAC('MlpPolicy', 'Pendulum-v0')
model.learn(2000, callback=checkpoint_callback)
3.2 EvalCallback
Periodically evaluate the performance of an agent, using a separate test environment. It will save the best model if the best_model_save_path folder is specified, and save the evaluation results in a NumPy archive (evaluations.npz) if the log_path folder is specified.
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
log_path='./logs/', eval_freq=500,
deterministic=True, render=False)
model = SAC('MlpPolicy', 'Pendulum-v0')
model.learn(5000, callback=eval_callback)
3.3 CallbackList
Class for chaining callbacks; they will be called sequentially. Alternatively, you can pass a list of callbacks directly to the learn() method; it will be converted automatically to a CallbackList.
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs3/')
# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs3/best_model',
log_path='./logs3/results', eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])
model = SAC('MlpPolicy', 'Pendulum-v0')
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)
3.4 StopTrainingOnRewardThreshold
This section is best read together with Section 2 (Event Callback).
Stop the training once a threshold in episodic reward (mean episode reward over the evaluations) has been reached (i.e., when the model is good enough). It must be used with the EvalCallback and uses the event triggered by a new best model.
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)
model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)
3.5 EveryNTimesteps
An EventCallback that will trigger its child callback every n_steps timesteps.
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps
# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs4/')
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
model = PPO('MlpPolicy', 'Pendulum-v0', verbose=1)
model.learn(int(2e4), callback=event_callback)
3.6 StopTrainingOnMaxEpisodes
Stop the training upon reaching the maximum number of episodes, regardless of the model's total_timesteps value. It also presumes that, for multiple environments, the desired behavior is that the agent trains on each env for max_episodes, and in total for max_episodes * n_envs episodes.
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)
model = A2C('MlpPolicy', 'Pendulum-v0', verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)
3.7 StopTrainingOnNoModelImprovement
Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations. The idea is to save time in experiments when you know the learning curves are reasonably well behaved: after many evaluations without improvement, learning has probably stabilized. It must be used with the EvalCallback and uses the event triggered after every evaluation.
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
# Separate evaluation env
eval_env = gym.make("Pendulum-v0")
# Stop training if there is no improvement after more than 3 evaluations
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v0", learning_rate=1e-3, verbose=1)
# Almost infinite number of timesteps, but the training will stop early
# as soon as the number of consecutive evaluations without model
# improvement is greater than 3
model.learn(int(1e10), callback=eval_callback)
Note: I could not find this class in the callbacks module of my installed stable-baselines3; it may not be implemented in the version I am using.