基于stable-baseline3 强化学习DQN的lunar lander的稳定控制

依赖包

鉴于不同版本的gym与stable-baselines3会产生冲突,在成功的基础上记录:
gym == 0.21.0
stable-baselines3 == 1.6.2
安装代码:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gym==0.21.0
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple stable-baselines3[extra]==1.6.2

lunar lander随机初始化action

import gym


# Create environment
env = gym.make("LunarLander-v2")

eposides = 10
for eq in range(eposides):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        env.render()
        rewards += reward
    print(rewards)

随机初始化,视频链接:lunar_lander_random

基于stable-baseline中DQN的实现

模型训练

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy


# Create environment
env = gym.make("LunarLander-v2")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")

这里已经将训练好的模型给保存为dqn_lunar.zip

模型测试

直接读取模型训练结果,进行测试

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy


# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_lunar", env=env)


# 测试接口
mean_reward, std_reward = evaluate_policy(
    model,
    model.get_env(),
    deterministic=True,
    render=True,
    n_eval_episodes=10)
print(mean_reward)

自己写测试模块

import gym
from stable_baselines3 import DQN


# Create environment
env = gym.make("LunarLander-v2")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
model = DQN.load("dqn_lunar", env=env)


eposides = 10
for eq in range(eposides):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        rewards += reward
    print(rewards)

测试结果:lunar_lander_DQN

网络架构优化

根据上述视频可以看出,在默认的DQN网络及参数,还不能使飞行器稳定停在月球上,将学习率改为5e-4,网络参数改为256,训练次数改为2500,000次,训练代码如下:

import gym
from stable_baselines3 import DQN


# Create environment
env = gym.make("LunarLander-v2")
model = DQN(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=5e-4,
    policy_kwargs={'net_arch':[256,256]})
    
model.learn(
    total_timesteps=int(2.5e6),
    progress_bar=True)

model.save("dqn_Net256_lunar_2500K")

模型测试代码如下:

import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy


# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_Net256_lunar_2500K", env=env)

mean_reward, std_reward = evaluate_policy(
    model,
    model.get_env(),
    deterministic=True,
    render=True,
    n_eval_episodes=10)
print(mean_reward)

测试视频:lunar_lander_256_2500K
由视频可以看出,月球车每次都能稳定停留在月球表面。

附录

有问题可以直接查官方文档
stable-baseline3: 手册
gym: 手册

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值