强化学习框架stable-baselines3简单案例

Stable-Baselines3 (SB3) 是一个基于 PyTorch 的库,提供了可靠的强化学习算法实现。它拥有简洁易用的接口,让用户能够直接使用现成的、最先进的无模型强化学习算法。
在这里插入图片描述

以下是一个基于强化学习和 Gym 中 mujocoAnt 环境的案例,使用了 Proximal Policy Optimization (PPO) 算法,这是一个适用于连续状态和动作空间的强化学习算法。


环境准备

安装依赖

确保安装以下库:

pip install gym[mujoco] stable-baselines3 shimmy
  • gym[mujoco]: 提供 MuJoCo 环境支持。
  • stable-baselines3: 包含多种强化学习算法的库,包括 PPO。
  • shimmy: stable-baselines3需要用到shimmy。

完整代码

实现 PPO 与 Ant 环境交互
import gym
from stable_baselines3 import PPO
import imageio

# 创建 Ant 环境l
env = gym.make("Ant-v4")

# 使用 Stable-Baselines3 的 PPO 算法
model = PPO(
    "MlpPolicy",  # 多层感知机作为策略网络
    env,
    verbose=1,
    learning_rate=0.0003,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
)

# 训练模型, total_timesteps自行调整
model.learn(total_timesteps=100000)

# 保存模型
model.save("ppo_ant")
# 加载模型
model = PPO.load("ppo_ant")
# 创建测试环境
env = gym.make("Ant-v4", render_mode="rgb_array")
# 存储每帧图像
frames = []
# 测试模型
obs, info = env.reset()
for _ in range(1000):
    env.render()
    frames.append(env.render())  # 捕获帧
    action, _ = model.predict(obs)
    next_state, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()

env.close()

# 保存为视频
imageio.mimsave("./ppo_ant_video.mp4", frames, fps=30)


代码解析

  1. 创建 Ant 环境

    • 使用 gym.make("Ant-v4") 创建 Ant 环境。
  2. 使用 PPO 算法

    • 策略网络:使用 MlpPolicy(多层感知机策略)。
    • 超参数设置
      • learning_rate:学习率,控制更新步长。
      • n_steps:每次更新前的时间步数。
      • batch_size:训练时的批量大小。
      • n_epochs:每次更新的训练轮数。
      • gamma:折扣因子,权衡短期与长期奖励。
      • gae_lambda:广义优势估计(GAE)的参数,用于稳定学习。
  3. 训练模型

    • 使用 model.learn() 函数训练模型。
  4. 测试模型

    • 使用 model.predict(obs) 获得动作决策。
    • 在环境中运行训练好的策略,通过渲染观察蚂蚁机器人的运动行为。

运行结果如下

若训练轮次较少,蚂蚁会翻倒

ppo训练轮数较少的情况

训练100000轮后,蚂蚁不再会翻倒

ppo训练轮数较多的情况


关键点与挑战

  1. 动作控制

    • 机器人通过连续动作控制腿部关节,需要策略学习如何协调运动。
    • 强化学习算法需要在高维动作空间中找到最优策略。
  2. 奖励函数设计

    • 环境自带的奖励函数主要基于蚂蚁的前进速度和能量效率。
    • 奖励设计需平衡速度、稳定性和能量消耗。
  3. 计算复杂度

    • 高维状态和动作空间会增加学习的难度,需要更长时间训练。

扩展方向

  1. 改进奖励函数

    • 自定义奖励函数,例如鼓励更多的能量效率或更复杂的步态。
  2. 多任务学习

    • Ant 环境中添加不同目标,例如绕过障碍或追踪目标点。
  3. 模型性能对比

    • 试验其他强化学习算法(如 DDPG、SAC、TD3),对比训练速度与性能。
  4. 迁移学习

    • 将训练好的蚂蚁策略应用于其他机器人环境,测试泛化能力。

总结

经过训练,蚂蚁机器人能够学会如何行走并避免翻倒。最终表现取决于训练时间和算法参数设置。渲染结果可以显示蚂蚁运动的动画效果。

笔者水平有限,若有不对的地方欢迎评论指正!

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

肥猪猪爸

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值