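Mastery comes from diligence and is lost in idleness; achievement comes from careful thought and is ruined by carelessness.
Starting with this post, I will be writing up notes on Stable-Baselines3 (SB3).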
1. Installation
```python
# Requirements:
#   python 3.9
#   PyTorch >= 1.11
# Use either one of the following two:
# !pip install stable-baselines3[extra]
# !pip install git+https://github.com/DLR-RM/stable-baselines3
# gym 0.26.2 is recommended
# !pip install gym==0.26.2

# Check the installation
import stable_baselines3 as sb3
import gym

print(sb3.__version__, gym.__version__)
```
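The first step in using SB3 is, of course, installing it in a virtual environment. If you are not familiar with creating virtual environments with Anaconda, see my earlier post on installing Anaconda. The installation itself is simple: on both Linux and Windows you can install with
pip install stable-baselines3[extra]
or
pip install stable-baselines3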
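Version 1.8.0 is recommended, since a lot of open-source code runs on it (for example pip install stable-baselines3[extra]==1.8.0). The [extra] variant only adds optional dependencies such as TensorBoard and OpenCV (plus tqdm and rich, which the progress bar used below needs); the library itself is the same.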
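Because the examples in this post assume specific versions (SB3 1.8.0, gym 0.26.2), it can help to check what is actually installed before running them. The following is a minimal sketch using only the standard library's importlib.metadata, so it does not crash when a package is missing; the helper name installed_versions is hypothetical and not part of SB3.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("stable-baselines3", "gym", "torch")):
    """Return {distribution name: version string, or None if not installed}."""
    # installed_versions is a hypothetical helper, not an SB3 function
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver else 'not installed'}")
```

Reading the metadata instead of importing the packages also avoids the (slow) PyTorch import just to check a version number.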
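You can also install directly from the GitHub repository:
pip install git+https://github.com/DLR-RM/stable-baselines3
To include the extra (plus tests and docs) dependencies, use:
pip install "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
Alternatively, clone the repository, cd into its folder, and install it in editable mode:
pip install -e .[docs,tests,extra]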
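Installation is straightforward, and the README on GitHub is quite detailed. The stable-baselines3 GitHub repository: SB3 (https://github.com/DLR-RM/stable-baselines3)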
2. Basic usage
SB3 is highly encapsulated, so basic usage is quite straightforward:
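```python
# SB3 packages many RL algorithms that can be imported directly
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# `env` must already exist here, e.g. env = MyWrapper() from Section 3 below
# verbose: (int) verbosity level, 0: no output, 1: info, 2: debug
# Instantiate the algorithm and pass in the environment
model = PPO('MlpPolicy', env, verbose=0)
# Start training (progress_bar=True needs tqdm and rich, installed with [extra])
model.learn(total_timesteps=20_000, progress_bar=True)
# This can also be written in one line:
# model = PPO('MlpPolicy', env, verbose=0).learn(total_timesteps=20_000, progress_bar=True)
# Save the model
model.save("../../model")
# Returns the mean and standard deviation of the episode rewards
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
# Load the model: load() is a class method that returns a new model,
# so its result has to be assigned
model = PPO.load("../../model", env=env)
# ......
```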
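evaluate_policy runs n_eval_episodes complete episodes with the model and, by default, returns the mean and standard deviation of the episode rewards (the total reward collected in each episode). The sketch below reproduces that idea in plain Python on a toy environment so the numbers are easy to check. ToyEnv, evaluate and the constant policy are hypothetical stand-ins, not SB3 code; the real function additionally handles vectorized environments, deterministic actions and Monitor wrappers.

```python
import statistics


class ToyEnv:
    """Hypothetical env with the old 4-tuple API: reward 1 per step, ends after `length` steps."""

    def __init__(self, length):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        done = self.t >= self.length
        return self.t, 1.0, done, {}


def evaluate(policy, env, n_eval_episodes=20):
    """Run full episodes; return (mean, population std) of the episode rewards."""
    episode_rewards = []
    for _ in range(n_eval_episodes):
        state, done, total = env.reset(), False, 0.0
        while not done:
            state, reward, done, _ = env.step(policy(state))
            total += reward
        episode_rewards.append(total)
    # SB3 uses np.std, which is the population std (ddof=0), hence pstdev
    return statistics.mean(episode_rewards), statistics.pstdev(episode_rewards)


if __name__ == "__main__":
    print(evaluate(lambda s: 0, ToyEnv(length=50), n_eval_episodes=5))  # (50.0, 0.0)
```

Note that the loop only stops when the environment reports done, so an environment whose episodes never end makes evaluation hang.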
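3. Custom wrappers
Environments are wrapped with wrappers, so we can also customize what a wrapper does in order to train on our own RL environment.
Defining the environment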
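```python
import gym


# Define the environment: wrap Pendulum-v1 and convert gym 0.26's API
# (reset -> (state, info), step -> 5-tuple) to the older API used by SB3 1.x
class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('Pendulum-v1')
        super().__init__(env)
        self.env = env

    def reset(self, **kwargs):
        # gym 0.26 returns (state, info); keep only the state
        state, _ = self.env.reset(**kwargs)
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        # Pendulum-v1 never terminates and is only truncated after 200 steps,
        # so both flags must be merged, otherwise the episode never ends
        done = terminated or truncated
        return state, reward, done, info


MyWrapper().reset()
```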
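Since gym 0.26, step() returns (obs, reward, terminated, truncated, info), while a wrapper like this hands SB3 1.x the older (obs, reward, done, info) form. Pendulum-v1 never sets terminated; its episodes end only through the 200-step time limit, which sets truncated. So done has to be terminated or truncated, or an episode never finishes. The pure-Python sketch below isolates that conversion; to_done_api is a hypothetical helper name. It also fills the "TimeLimit.truncated" info key, the convention of older gym's TimeLimit wrapper, which, as far as I know, SB3 1.x reads to treat time-outs differently from real terminal states.

```python
def to_done_api(step_result):
    """Turn a gym>=0.26 5-tuple into the older (obs, reward, done, info) 4-tuple."""
    obs, reward, terminated, truncated, info = step_result
    done = terminated or truncated
    info = dict(info)  # copy so the caller's dict is not modified
    # Older gym's TimeLimit flagged time-outs with this key; a real
    # termination on the same step takes precedence over the time-out
    info["TimeLimit.truncated"] = truncated and not terminated
    return obs, reward, done, info


if __name__ == "__main__":
    # Pendulum-like last step: not terminated, but truncated by the time limit
    print(to_done_api(([1.0, 0.0, 0.0], -0.5, False, True, {})))
```

Inside a wrapper, step() could then simply be return to_done_api(self.env.step(action)).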
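To limit the number of steps per episode, keep a counter of the current step, increment it on every step() call, and check the limit inside step(). The wrapper can be modified like this: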
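```python
# MaxStepWrapper is an example name; it wraps an env that already uses the
# 4-tuple API, e.g. env = MaxStepWrapper(MyWrapper())
class MaxStepWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.current_step = 0

    def reset(self, **kwargs):
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        self.current_step += 1
        state, reward, done, info = self.env.step(action)
        # Override done: end the episode after 100 steps
        if self.current_step >= 100:
            done = True
        return state, reward, done, info
```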
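The counter logic above does not depend on gym itself. A minimal pure-Python sketch, using a hypothetical EndlessEnv that never reports done and a gym-free MaxSteps class, shows that the counter alone ends the episode after 100 steps and that reset() restarts the count. gym also ships gym.wrappers.TimeLimit(env, max_episode_steps=...), but in gym 0.26 it expects the 5-tuple step API, so it would have to sit inside a 4-tuple wrapper like MyWrapper rather than outside it.

```python
class EndlessEnv:
    """Hypothetical env (old 4-tuple API) whose episodes never end on their own."""

    def reset(self):
        return 0.0

    def step(self, action):
        return 0.0, 0.0, False, {}


class MaxSteps:
    """Same counter logic as the wrapper above, without the gym dependency."""

    def __init__(self, env, max_steps=100):
        self.env = env
        self.max_steps = max_steps
        self.current_step = 0

    def reset(self):
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        self.current_step += 1
        state, reward, done, info = self.env.step(action)
        if self.current_step >= self.max_steps:
            done = True  # force the end of the episode
        return state, reward, done, info


def episode_length(env):
    """Count the steps until the env reports done."""
    env.reset()
    steps, done = 0, False
    while not done:
        _, _, done, _ = env.step(None)
        steps += 1
    return steps


if __name__ == "__main__":
    env = MaxSteps(EndlessEnv())
    print(episode_length(env), episode_length(env))  # 100 100
```

Resetting the counter in reset() matters: without it, every episode after the first would end after a single step.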
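For continuous actions and observations, SB3 works with gym.spaces.Box (it also supports Discrete, MultiDiscrete and MultiBinary spaces). To change the action or observation space, assign a new gym.spaces.Box. Its dtype defaults to float32; to use float64, pass dtype=np.float64, or change the default in gym's box.py (passing the argument explicitly is the less invasive option):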
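```python
import numpy as np
# env: the wrapped environment; shapes and bounds are placeholders and must match the real data
env.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
env.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
```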
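Replacing action_space only changes what the agent samples from; the wrapped Pendulum-v1 still expects torques in [-2, 2] (SB3's documentation recommends a normalized, symmetric range such as [-1, 1] for continuous actions). If the agent acts in [-1, 1], the wrapper has to map each action back linearly before passing it on. gym.wrappers.RescaleAction does this; the sketch below just shows the arithmetic, low + (a + 1) * (high - low) / 2, in plain Python, with rescale_action as a hypothetical name.

```python
def rescale_action(a, low, high):
    """Map an action a in [-1, 1] linearly onto [low, high]."""
    # Values outside [-1, 1] are clipped first so the result stays in range
    a = max(-1.0, min(1.0, a))
    return low + (a + 1.0) * (high - low) / 2.0


if __name__ == "__main__":
    # Pendulum-v1 torque range is [-2, 2]
    print([rescale_action(a, -2.0, 2.0) for a in (-1.0, -0.5, 0.0, 1.0)])  # [-2.0, -1.0, 0.0, 2.0]
```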
Testing the environment
```python
# Test the environment with random actions
def test(env, wrap_action_in_list=False):
    print(env)
    state = env.reset()
    over = False
    step = 0
    while not over:
        action = env.action_space.sample()
        if wrap_action_in_list:
            action = [action]
        next_state, reward, over, _ = env.step(action)
        if step % 20 == 0:
            print(step, state, action, reward)
        if step > 200:
            break
        state = next_state
        step += 1


test(MyWrapper())
```
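4. Training with multiple environments
DummyVecEnv runs multiple environments one after another in a single process, whereas SubprocVecEnv runs each environment in its own subprocess, so the environments run in parallel across processes (not threads).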
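Conceptually, DummyVecEnv keeps a list of environments built from a list of constructor functions, steps them one after another in the current process, collects the results, and automatically resets any environment whose episode has ended, storing that last observation in info["terminal_observation"]. SubprocVecEnv does the same, but each environment lives in its own process and is driven through pipes. The toy sketch below mirrors the DummyVecEnv behavior with hypothetical names and plain Python lists instead of NumPy arrays.

```python
class CountdownEnv:
    """Hypothetical env (old 4-tuple API): observation counts steps, done after `length` steps."""

    def __init__(self, length):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length, {}


class SequentialVecEnv:
    """Toy version of the DummyVecEnv idea: N envs stepped in turn, in one process."""

    def __init__(self, env_fns):
        # Like DummyVecEnv, take a list of callables that each build one env
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        obs, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            o, r, d, info = env.step(action)
            if d:
                # Auto-reset: keep the final observation, return the new first one
                info = dict(info, terminal_observation=o)
                o = env.reset()
            obs.append(o)
            rewards.append(r)
            dones.append(d)
            infos.append(info)
        return obs, rewards, dones, infos


if __name__ == "__main__":
    venv = SequentialVecEnv([lambda: CountdownEnv(2), lambda: CountdownEnv(3)])
    venv.reset()
    for _ in range(3):
        print(venv.step([None, None]))
```

This is also why the training code passes [MyWrapper] * N: a list of callables, each of which creates one environment.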
```python
import time

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv


def test_multiple_env(dumm, N):
    if dumm:
        # DummyVecEnv: N environments stepped one after another in a single process
        env = DummyVecEnv([MyWrapper] * N)
    else:
        # SubprocVecEnv: each environment runs in its own subprocess
        # (the 'fork' start method is not available on Windows)
        env = SubprocVecEnv([MyWrapper] * N, start_method='fork')
    start = time.time()
    # Train a model
    model = PPO('MlpPolicy', env, verbose=0).learn(total_timesteps=5000)
    print('elapsed time =', time.time() - start)
    # Close the environments
    env.close()
    # Evaluate on a single, non-vectorized environment
    return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)


test_multiple_env(dumm=True, N=2)
```
With the function above we can choose between DummyVecEnv and SubprocVecEnv and compare the two:
```python
test_multiple_env(dumm=True, N=10)
test_multiple_env(dumm=False, N=10)
```
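start_method='fork' only exists on POSIX systems; Windows supports only 'spawn', and with 'spawn' or 'forkserver' the SubprocVecEnv must be created under an if __name__ == "__main__": guard. For reference, SubprocVecEnv's own default (start_method=None) is 'forkserver' where available and 'spawn' otherwise. A small standard-library sketch for picking a method the current platform supports; pick_start_method is a hypothetical helper.

```python
import multiprocessing


def pick_start_method(preferred=("fork", "forkserver", "spawn")):
    """Return the first preferred start method supported on this platform."""
    available = multiprocessing.get_all_start_methods()
    for method in preferred:
        if method in available:
            return method
    raise RuntimeError(f"none of {preferred} is available: {available}")


if __name__ == "__main__":
    # Typically 'fork' on Linux and 'spawn' on Windows
    print(pick_start_method())
    # Usage (needs SB3 and MyWrapper, so shown only as a comment):
    # env = SubprocVecEnv([MyWrapper] * 10, start_method=pick_start_method())
```

Keep in mind that SubprocVecEnv only pays off when stepping the environment is expensive; for a cheap environment like Pendulum, the inter-process communication overhead can make it slower than DummyVecEnv.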
5. The Callback class
The SB3 user guide describes the Callback class as follows:
```python
from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    A custom callback that derives from ``BaseCallback``.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    def __init__(self, verbose: int = 0):
        super().__init__(verbose)
        # Those variables will be accessible in the callback
        # (they are defined in the base class)
        # The RL model
        # self.model = None  # type: BaseAlgorithm
        # An alias for self.model.get_env(), the environment used for training
        # self.training_env # type: VecEnv
        # Number of time the callback was called
        # self.n_calls = 0  # type: int
        # num_timesteps = n_envs * n times env.step() was called
        # self.num_timesteps = 0  # type: int
        # local and global variables
        # self.locals = {}  # type: Dict[str, Any]
        # self.globals = {}  # type: Dict[str, Any]
        # The logger object, used to report things in the terminal
        # self.logger # type: stable_baselines3.common.logger.Logger
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        # self.parent = None  # type: Optional[BaseCallback]

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interaction
        using the current policy.
        This event is triggered before collecting new samples.
        """
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        pass

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass
```
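The docstrings above imply a fixed call order during model.learn(): _on_training_start once; then, for every rollout, _on_rollout_start, _on_step after each env.step(), and _on_rollout_end before the policy update; finally _on_training_end. If _on_step returns False, training stops early. The toy loop below (hypothetical names, no SB3 involved) records that order, which helps when deciding which hook should hold logging or early-stopping code. It models SB3's on-policy loop (e.g. PPO) as I understand it, where an abort inside a rollout returns before the rollout-end hook but learn() still calls the training-end hook. With SB3 itself, a callback instance is passed as model.learn(total_timesteps=..., callback=CustomCallback()).

```python
class RecordingCallback:
    """Toy stand-in for a BaseCallback subclass: records which hook ran."""

    def __init__(self, stop_after=None):
        self.events = []
        self.n_calls = 0
        self.stop_after = stop_after  # hypothetical early-stopping threshold

    def _on_training_start(self):
        self.events.append("training_start")

    def _on_rollout_start(self):
        self.events.append("rollout_start")

    def _on_step(self):
        self.n_calls += 1
        self.events.append("step")
        # Returning False asks the training loop to stop early
        return self.stop_after is None or self.n_calls < self.stop_after

    def _on_rollout_end(self):
        self.events.append("rollout_end")

    def _on_training_end(self):
        self.events.append("training_end")


def fake_learn(callback, n_rollouts, n_steps):
    """Simplified on-policy loop: collect rollouts of n_steps, calling the hooks."""
    callback._on_training_start()
    for _ in range(n_rollouts):
        callback._on_rollout_start()
        aborted = False
        for _ in range(n_steps):
            # env.step() would happen here
            if callback._on_step() is False:
                aborted = True
                break
        if aborted:
            break  # the rollout-end hook is skipped when a rollout is aborted
        callback._on_rollout_end()
        # the policy update would happen here
    callback._on_training_end()


if __name__ == "__main__":
    cb = RecordingCallback()
    fake_learn(cb, n_rollouts=2, n_steps=2)
    print(cb.events)
```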
The code in this post was written with reference to a GPT service accessible in China, linked here: GPT channel