stable-baselines3 [Stable Baselines] Part 1: Stable-Baselines3 Basics

Work is perfected by diligence and ruined by idleness; conduct is built by thought and destroyed by carelessness.

Starting with this post, I will begin documenting my notes on stable-baselines3.

1. Installation

# Environment requirements
# python 3.9
# PyTorch >= 1.11

# Either one of these is fine
#!pip install stable-baselines3[extra]
#!pip install git+https://github.com/DLR-RM/stable-baselines3

# gym 0.26.2 is recommended
#!pip install gym==0.26.2

# Test the installation
import stable_baselines3 as sb3
import gym

sb3.__version__, gym.__version__

To use SB3, the first step is to install it in a virtual environment. If you are not familiar with creating a virtual environment with Anaconda, you can refer to my earlier post on Anaconda installation. The steps are fairly simple; on both Linux and Windows you can run

pip install stable-baselines3[extra]

or

pip install stable-baselines3

to install it.

Version 1.8.0 is recommended, since many open-source projects are written against it. The [extra] variant pulls in a few additional dependencies such as Tensorboard and OpenCV; otherwise the two are identical.

We can also install directly from the remote repository:

pip install git+https://github.com/DLR-RM/stable-baselines3

To install the extra version from git, the command is:

pip install "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"

Or, after cloning the repository and cd-ing into the folder:

pip install -e .[docs,tests,extra]

Installation is straightforward, and the README on GitHub is quite detailed. The stable-baselines3 GitHub repository is linked here: SB3

2. Basic Usage

SB3 is highly encapsulated, so basic usage is quite simple:

# SB3 wraps many reinforcement learning algorithms that can be imported and used directly
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import gym

# Create the environment to train on (any gym environment works; Pendulum-v1 is used here as an example)
env = gym.make('Pendulum-v1')

# verbose: (int) Verbosity level 0: no output 1: info 2: debug
# Instantiate the algorithm model here and pass in the environment
model = PPO('MlpPolicy', env, verbose=0)
# Start training
model.learn(total_timesteps=20_000, progress_bar=True)
# This can also be written as:
# model = PPO('MlpPolicy', env, verbose=0).learn(total_timesteps=20_000, progress_bar=True)
# Save the model
model.save("../../model")
evaluate_policy(model, env, n_eval_episodes=20)
# Load the model (load is a classmethod, so keep the returned model)
model = PPO.load("../../model")
......
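
Once a model has been trained or loaded, it can be used for inference through predict. Below is a minimal sketch using the model's own vectorized training environment; the number of steps is an arbitrary choice for illustration:

# Retrieve the (vectorized) training environment from the model
vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(200):
    # deterministic=True turns off exploration noise during evaluation
    action, _states = model.predict(obs, deterministic=True)
    # VecEnv.step always returns a 4-tuple and auto-resets finished episodes
    obs, rewards, dones, infos = vec_env.step(action)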

3. Custom Wrappers

Environments are packaged through wrappers, so we can also customize a wrapper's contents in order to train on our own reinforcement learning environments.

Define the environment

import gym
# Define the environment
class MyWrapper(gym.Wrapper):

    def __init__(self):
        env = gym.make('Pendulum-v1')
        super().__init__(env)
        self.env = env
 
    def reset(self):
        # gym 0.26 returns (state, info); keep only the state, as SB3 1.x expects
        state, _ = self.env.reset()
        return state

    def step(self, action):
        # gym 0.26 returns (state, reward, terminated, truncated, info); fold it back into the old 4-tuple API
        state, reward, done, _, info = self.env.step(action)
        return state, reward, done, info
        
MyWrapper().reset()
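
Before training on a custom environment, SB3 also ships an environment checker that validates the observation/action spaces and the reset/step return values. A quick sketch, assuming an SB3 1.x version whose checker accepts the old-style gym API that MyWrapper returns:

from stable_baselines3.common.env_checker import check_env

# Prints warnings (or raises) if the environment does not follow the interface SB3 expects
check_env(MyWrapper(), warn=True)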

If we want to limit the number of steps per episode, we need a variable that records the current step, increment it on every call, and add the limit condition inside the step function. We can modify the code in the wrapper class as follows:

  def __init__(self):
      env = gym.make('Pendulum-v1')
      super().__init__(env)
      self.current_step = 0

  def reset(self):
      self.current_step = 0
      state, _ = self.env.reset()
      return state

  def step(self, action):
      self.current_step += 1
      state, reward, done, _, info = self.env.step(action)

      # Override the done flag once the step limit is reached
      if self.current_step >= 100:
          done = True

      return state, reward, done, info
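
Note that gym already provides a TimeLimit wrapper that achieves much the same effect; a short sketch of that built-in alternative, assuming gym 0.26 (where the limit sets the truncated flag rather than done):

from gym.wrappers import TimeLimit
import gym

# Roughly equivalent to the manual counter above: episodes are cut off after 100 steps
env = TimeLimit(gym.make('Pendulum-v1'), max_episode_steps=100)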

SB3 expects Box-type action and observation spaces. If we need to modify the action space or observation space, we can use gym.spaces.Box (the default dtype is float32; we can either change the default to float64 in box.py or pass dtype=float64 explicitly):

env.action_space = gym.spaces.Box(low=-1, high=1)
env.observation_space = gym.spaces.Box(low=-1, high=1)
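
As a slightly fuller sketch (the shapes, bounds, and dtypes below are illustrative assumptions, not values required by SB3):

import numpy as np
import gym

# A 1-dimensional action in [-1, 1], stored as float64 instead of the default float32
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
# A 3-dimensional observation in [-1, 1] with the default float32 dtype
observation_space = gym.spaces.Box(low=-1, high=1, shape=(3,))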

Test the environment

# Test the environment
def test(env, wrap_action_in_list=False):
    print(env)

    state = env.reset()
    over = False
    step = 0

    while not over:
        action = env.action_space.sample()

        if wrap_action_in_list:
            action = [action]

        next_state, reward, over, _ = env.step(action)

        if step % 20 == 0:
            print(step, state, action, reward)

        if step > 200:
            break

        state = next_state
        step += 1


test(MyWrapper())

4. Training with Multiple Environments

DummyVecEnv runs multiple environments sequentially in a single process, while SubprocVecEnv runs each environment in its own subprocess.

from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import time


def test_multiple_env(dumm, N):
    if dumm:
        # DummyVecEnv: run multiple environments sequentially in a single process
        env = DummyVecEnv([MyWrapper] * N)
    else:
        # SubprocVecEnv: run each environment in its own subprocess
        # ('fork' is only available on Unix-like systems; use 'spawn' on Windows)
        env = SubprocVecEnv([MyWrapper] * N, start_method='fork')

    start = time.time()

    # Train a model
    model = PPO('MlpPolicy', env, verbose=0).learn(total_timesteps=5000)

    print('Time elapsed =', time.time() - start)

    # Close the environments
    env.close()

    # Evaluate the trained model on a single environment
    return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)


test_multiple_env(dumm=True, N=2)

With the function above we can choose between DummyVecEnv and SubprocVecEnv and compare the two:

test_multiple_env(dumm=True, N=10)
test_multiple_env(dumm=False, N=10)
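
SB3 also provides a make_vec_env helper that builds a vectorized environment in one call. A short sketch of the same idea using it (assuming the MyWrapper class defined earlier and the same PPO import as above):

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# 10 copies of the custom environment, each running in its own subprocess
vec_env = make_vec_env(MyWrapper, n_envs=10, vec_env_cls=SubprocVecEnv)
model = PPO('MlpPolicy', vec_env, verbose=0).learn(total_timesteps=5000)
vec_env.close()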

5. The Callback Class

The SB3 user guide describes the Callback class like this:

from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    A custom callback that derives from ``BaseCallback``.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    def __init__(self, verbose: int = 0):
        super().__init__(verbose)
        # Those variables will be accessible in the callback
        # (they are defined in the base class)
        # The RL model
        # self.model = None  # type: BaseAlgorithm
        # An alias for self.model.get_env(), the environment used for training
        # self.training_env # type: VecEnv
        # Number of time the callback was called
        # self.n_calls = 0  # type: int
        # num_timesteps = n_envs * n times env.step() was called
        # self.num_timesteps = 0  # type: int
        # local and global variables
        # self.locals = {}  # type: Dict[str, Any]
        # self.globals = {}  # type: Dict[str, Any]
        # The logger object, used to report things in the terminal
        # self.logger # type: stable_baselines3.common.logger.Logger
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        # self.parent = None  # type: Optional[BaseCallback]

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interaction
        using the current policy.
        This event is triggered before collecting new samples.
        """
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        pass

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass
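
As a concrete example, here is a minimal sketch of a custom callback that prints the step counter whenever it hits a multiple of 1000; the class name and printing interval are illustrative assumptions. It is passed to learn() through the callback argument:

class PrintStepsCallback(BaseCallback):
    """Print training progress whenever num_timesteps hits a multiple of print_freq."""

    def __init__(self, print_freq: int = 1000, verbose: int = 0):
        super().__init__(verbose)
        self.print_freq = print_freq

    def _on_step(self) -> bool:
        if self.num_timesteps % self.print_freq == 0:
            print(f'num_timesteps = {self.num_timesteps}')
        # Returning True keeps training going; returning False aborts it early
        return True


# Usage: model.learn(total_timesteps=20_000, callback=PrintStepsCallback())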

The code in this post was written with reference to a GPT service available in China; link: GPT通道
