StableBaselines3强化学习框架简明教程

由强雷师兄友情支持。

Stable Baselines3 (SB3) 是基于 PyTorch 的强化学习库,提供多种经典强化学习算法(如 PPO、DQN、A2C 等)。

资源链接:

1. 官方文档:

Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations — Stable Baselines3 2.6.0 documentation

2. GitHub 仓库:

https://github.com/lansinuote/StableBaselines3_SimpleCases/tree/main

1. 安装

需要先安装Python(3.9+)、PyTorch >= 2.3 和 gymnasium(替代旧版gym)

# 安装 Stable Baselines3 和依赖库
pip install stable-baselines3[extra]

详见网页

2. 快速使用

SB3的主要功能就是封装了底层代码,这样在使用的时候就可以快速调用并自动调参。

以下为核心步骤:

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor


# 定义自定义环境包装器
class MyWrapper(gym.Wrapper):

    def __init__(self, env):
        super().__init__(env)

    def reset(self, **kwargs):
        # 确保 reset 支持 seed 参数,并返回 state 和 info
        state, info = self.env.reset(**kwargs)
        return state, info  # 返回两个值

    def step(self, action):
        # 确保 step 返回值符合 Gymnasium 标准
        state, reward, terminated, truncated, info = self.env.step(action)
        return state, reward, terminated, truncated, info  # 返回五个值


# 创建原始环境并包装
env = gym.make('CartPole-v1')
env = MyWrapper(env)
env = Monitor(env)  # 使用 Monitor 包装器记录统计数据

# 测试环境重置
env.reset()

# 创建 PPO 模型
model = PPO('MlpPolicy', env, verbose=0)

# 在训练前测试模型性能
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
print(f"Before training: mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")

# 训练模型
model.learn(total_timesteps=20, progress_bar=True)

# 在训练后测试模型性能
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
print(f"After training: mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")

注意:SB3 不再直接支持旧版的 gym 库(如 gym 0.26 或更高版本)。SB3 推荐使用 gymnasium(即 gym 的分支项目)作为环境接口,以确保兼容性和持续维护。 Gymnasium 和Gym 某些接口设计上存在差异,reset() 方法和step() 方法返回值的数量不同,请注意修改环境的相关代码哈。🙂

3 .模型的训练和保存

同上定义游戏环境后,模型训练代码如下:

import gymnasium as gym
from stable_baselines3 import PPO

# 创建 PPO 模型
model = PPO('MlpPolicy', env, verbose=0).learn(total_timesteps=8000, progress_bar=True)
# progress bar 表示进度条

# 保存模型
model.save('models/save')

# 加载模型
model = PPO.load('models/save')

model

运行后模型会自动保存在 models 目录下的save.zip文件里。

4. 进阶--包装Gym环境

在SB3中,可以对动作空间重新定义。

1. 动作空间的缩放

在下面这个例子中,原问题中动作空间为\in [-2, 2]的连续值,但是在定义动作空间时我们将其压缩至\in [-1,1],又在step() 中重新缩放为原值域。具体操作如下:

import numpy as np


#修改动作空间
class NormalizeActionWrapper(gym.Wrapper):

    def __init__(self, env):
        #获取动作空间
        action_space = env.action_space

        #动作空间必须是连续值
        assert isinstance(action_space, gym.spaces.Box)

        #重新定义动作空间,在正负一之间的连续值
        #这里其实只影响env.action_space.sample的返回结果
        #实际在计算时,还是正负2之间计算的
        env.action_space = gym.spaces.Box(low=-1,
                                          high=1,
                                          shape=action_space.shape,
                                          dtype=np.float32)

        super().__init__(env)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        #重新缩放动作的值域
        action = action * 2.0

        if action > 2.0:
            action = 2.0

        if action < -2.0:
            action = -2.0

        return self.env.step(action)


test(NormalizeActionWrapper(Pendulum()))

2.  修改环境的状态空间

如下面的例子。原状态空间为3列,定义 StateStepWrapper 方法新增一列在 [0,1] 之间的状态。

    from gym.wrappers import TimeLimit
    import numpy as np
    
    
    class StateStepWrapper(gym.Wrapper):
    
        def __init__(self, env):
            # 状态空间必须是连续值
            assert isinstance(env.observation_space, gym.spaces.Box)
    
            # 增加一个新状态字段
            low = np.concatenate([env.observation_space.low, [0.0]])
            high = np.concatenate([env.observation_space.high, [1.0]])
    
            env.observation_space = gym.spaces.Box(low=low,
                                                   high=high,
                                                   dtype=np.float32)
    
            super().__init__(env)
    
            self.step_current = 0
    
        def reset(self):
            self.step_current = 0
            return np.concatenate([self.env.reset(), [0.0]])
    
        def step(self, action):
            self.step_current += 1
            state, reward, done, info = self.env.step(action)
    
            # 根据 step_max 修改 done
            if self.step_current >= 100:
                done = True
    
            return self.get_state(state), reward, done, info
    
        def get_state(self, state):
            # 添加一个新的 state 字段
            state_step = self.step_current / 100
    
            return np.concatenate([state, [state_step]])

    3. Normalization 归一化

    SB3 内置的Wapper:VecNormalize 也可以对 state 和 reward 进行 归一化。具体操作如下:

    import gymnasium as gym
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
    
    # 创建原始环境
    env = gym.make("Pendulum-v1")
    
    # 矢量化环境
    env = DummyVecEnv([lambda: env])
    
    # 应用 VecNormalize
    env = VecNormalize(env, 
                      norm_obs=True, 
                      norm_reward=True, 
                      clip_obs=10., 
                      clip_reward=10.)
    
    # 创建并训练模型
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10_000)
    
    # 保存模型和归一化参数
    model.save("ppo_pendulum")
    env.save("vec_normalize.pkl")
    
    # 加载模型和归一化参数
    model = PPO.load("ppo_pendulum")
    env = VecNormalize.load("vec_normalize.pkl", env)
    env.training = False
    env.norm_reward = True  # 如果需要加载奖励归一化参数
    
    # 测试模型
    obs, _ = env.reset()
    for _ in range(1000):
        action, _states = model.predict(obs)
        obs, rewards, dones, truncs, info = env.step(action)
        if dones[0]:
            obs, _ = env.reset()

    5. 进阶--参数调优

    • 学习率learning_rate=3e-4
    • 批处理大小batch_size=64
    • 折扣因子gamma=0.99(未来奖励的权重)
    • 网络结构:通过policy_kwargs自定义网络
    from stable_baselines3.common.policies import ActorCriticPolicy
    
    policy_kwargs = dict(
        net_arch=[dict(pi=[128, 128], vf=[128, 128])]  # 策略和价值网络的层数
    )
    model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs)

    评论
    添加红包

    请填写红包祝福语或标题

    红包个数最小为10个

    红包金额最低5元

    当前余额3.43前往充值 >
    需支付:10.00
    成就一亿技术人!
    领取后你会自动成为博主和红包主的粉丝 规则
    hope_wisdom
    发出的红包
    实付
    使用余额支付
    点击重新获取
    扫码支付
    钱包余额 0

    抵扣说明:

    1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
    2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

    余额充值