【2021-12-07】Training on the CartPole environment

This article shows how to use callbacks such as EarlyStopping, ModelCheckpoint, and LearningRateScheduler to optimize training in a reinforcement learning setting: EarlyStopping guards against overfitting, ModelCheckpoint saves the best model weights, and LearningRateScheduler adjusts the learning rate dynamically. Changing the policy and switching to alternative algorithms are also discussed.
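The callbacks named above have close counterparts in stable-baselines3's own callback system. As a minimal preview sketch, with an illustrative reward threshold, evaluation frequency, and save path (not part of the run below):

import os
import gym
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Stop training once the mean evaluation reward reaches the threshold
# (CartPole-v0 is commonly considered solved around 195)
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=195, verbose=1)

# Evaluate periodically, keep the best weights, and trigger the stop condition
eval_env = gym.make('CartPole-v0')
eval_callback = EvalCallback(eval_env,
                             callback_on_new_best=stop_callback,
                             eval_freq=10000,
                             best_model_save_path=os.path.join('Training', 'Save Models'),
                             verbose=1)
# later: model.learn(total_timesteps=20000, callback=eval_callback)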

Import packages

import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

Load environments

environment_name = 'CartPole-v0'
env = gym.make(environment_name)

# Baseline: run a few episodes with purely random actions
episodes = 5
for episode in range(1, episodes + 1):
    state = env.reset()
    done = False
    score = 0
    while not done:
        env.render()
        action = env.action_space.sample()            # random action
        n_state, reward, done, info = env.step(action)
        score += reward
    print('episode:{} score:{}'.format(episode, score))
env.close()
episode:1 score:19.0
episode:2 score:28.0
episode:3 score:23.0
episode:4 score:36.0
episode:5 score:26.0

env.reset()

Train an agent

# make your directories first
log_path = os.path.join('Training','Logs')
log_path
'Training\\Logs'
env = gym.make(environment_name)
env = DummyVecEnv([lambda: env])   # wrap in a (single-process) vectorized environment
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)
Using cpu device
model.learn(total_timesteps = 20000)
Logging to Training\Logs\PPO_4
-----------------------------
| time/              |      |
|    fps             | 923  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 952         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008253496 |
|    clip_fraction        | 0.094       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | 0.00826     |
|    learning_rate        | 0.0003      |
|    loss                 | 5.01        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0152     |
|    value_loss           | 51.7        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 953         |
|    iterations           | 3           |
|    time_elapsed         | 6           |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.008827355 |
|    clip_fraction        | 0.0568      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.667      |
|    explained_variance   | 0.08        |
|    learning_rate        | 0.0003      |
|    loss                 | 14.8        |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0173     |
|    value_loss           | 38.1        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 953         |
|    iterations           | 4           |
|    time_elapsed         | 8           |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.013173467 |
|    clip_fraction        | 0.157       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.62       |
|    explained_variance   | 0.319       |
|    learning_rate        | 0.0003      |
|    loss                 | 22.3        |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0252     |
|    value_loss           | 43.1        |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 951          |
|    iterations           | 5            |
|    time_elapsed         | 10           |
|    total_timesteps      | 10240        |
| train/                  |              |
|    approx_kl            | 0.0064320453 |
|    clip_fraction        | 0.0451       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.613       |
|    explained_variance   | 0.22         |
|    learning_rate        | 0.0003       |
|    loss                 | 19.4         |
|    n_updates            | 40           |
|    policy_gradient_loss | -0.0138      |
|    value_loss           | 62.2         |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 952         |
|    iterations           | 6           |
|    time_elapsed         | 12          |
|    total_timesteps      | 12288       |
| train/                  |             |
|    approx_kl            | 0.005620899 |
|    clip_fraction        | 0.0605      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.596      |
|    explained_variance   | 0.529       |
|    learning_rate        | 0.0003      |
|    loss                 | 17.2        |
|    n_updates            | 50          |
|    policy_gradient_loss | -0.016      |
|    value_loss           | 52.4        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 951         |
|    iterations           | 7           |
|    time_elapsed         | 15          |
|    total_timesteps      | 14336       |
| train/                  |             |
|    approx_kl            | 0.010773982 |
|    clip_fraction        | 0.115       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.6        |
|    explained_variance   | 0.781       |
|    learning_rate        | 0.0003      |
|    loss                 | 8.86        |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.0135     |
|    value_loss           | 41.3        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 955         |
|    iterations           | 8           |
|    time_elapsed         | 17          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.007079646 |
|    clip_fraction        | 0.0451      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.562      |
|    explained_variance   | 0.264       |
|    learning_rate        | 0.0003      |
|    loss                 | 55.1        |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.00813    |
|    value_loss           | 93.9        |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 957          |
|    iterations           | 9            |
|    time_elapsed         | 19           |
|    total_timesteps      | 18432        |
| train/                  |              |
|    approx_kl            | 0.0034424942 |
|    clip_fraction        | 0.0265       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.581       |
|    explained_variance   | 0.343        |
|    learning_rate        | 0.0003       |
|    loss                 | 29.5         |
|    n_updates            | 80           |
|    policy_gradient_loss | -0.00381     |
|    value_loss           | 76           |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 955          |
|    iterations           | 10           |
|    time_elapsed         | 21           |
|    total_timesteps      | 20480        |
| train/                  |              |
|    approx_kl            | 0.0026272014 |
|    clip_fraction        | 0.0136       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.576       |
|    explained_variance   | 0.115        |
|    learning_rate        | 0.0003       |
|    loss                 | 122          |
|    n_updates            | 90           |
|    policy_gradient_loss | -0.00157     |
|    value_loss           | 107          |
------------------------------------------





<stable_baselines3.ppo.ppo.PPO at 0x2d49081cfc8>
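The learning rate logged above (0.0003) is PPO's default; it and the other hyperparameters are set when the model is constructed, and passing a callable instead of a float turns the learning rate into a schedule. A hedged sketch, with illustrative values and the name tuned_model chosen only for this example:

# SB3 calls the schedule with progress_remaining, which decays from 1 to 0
def linear_schedule(initial_value):
    def schedule(progress_remaining):
        return progress_remaining * initial_value
    return schedule

tuned_model = PPO('MlpPolicy', env,
                  learning_rate=linear_schedule(3e-4),
                  n_steps=2048,      # rollout length before each update
                  batch_size=64,
                  n_epochs=10,       # gradient epochs per rollout
                  gamma=0.99,
                  verbose=1,
                  tensorboard_log=log_path)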

Save and reload model

PPO_Path = os.path.join('Training', 'Save Models', 'PPO_Model_CartPole')
model.save(PPO_Path)
del model
model = PPO.load(PPO_Path, env=env)   # reload the weights and attach the environment
model.learn(total_timesteps=10000)
Logging to Training\Logs\PPO_5
-----------------------------
| time/              |      |
|    fps             | 1823 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1237         |
|    iterations           | 2            |
|    time_elapsed         | 3            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0072634863 |
|    clip_fraction        | 0.0615       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.553       |
|    explained_variance   | 0.354        |
|    learning_rate        | 0.0003       |
|    loss                 | 36.4         |
|    n_updates            | 110          |
|    policy_gradient_loss | -0.00478     |
|    value_loss           | 112          |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1124         |
|    iterations           | 3            |
|    time_elapsed         | 5            |
|    total_timesteps      | 6144         |
| train/                  |              |
|    approx_kl            | 0.0051936586 |
|    clip_fraction        | 0.0519       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.563       |
|    explained_variance   | 0.576        |
|    learning_rate        | 0.0003       |
|    loss                 | 121          |
|    n_updates            | 120          |
|    policy_gradient_loss | -0.00451     |
|    value_loss           | 102          |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1074         |
|    iterations           | 4            |
|    time_elapsed         | 7            |
|    total_timesteps      | 8192         |
| train/                  |              |
|    approx_kl            | 0.0016217632 |
|    clip_fraction        | 0.00737      |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.55        |
|    explained_variance   | 0.323        |
|    learning_rate        | 0.0003       |
|    loss                 | 108          |
|    n_updates            | 130          |
|    policy_gradient_loss | -0.00079     |
|    value_loss           | 116          |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1046         |
|    iterations           | 5            |
|    time_elapsed         | 9            |
|    total_timesteps      | 10240        |
| train/                  |              |
|    approx_kl            | 0.0019479269 |
|    clip_fraction        | 0.0187       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.544       |
|    explained_variance   | 0.659        |
|    learning_rate        | 0.0003       |
|    loss                 | 64.3         |
|    n_updates            | 140          |
|    policy_gradient_loss | -0.00166     |
|    value_loss           | 93           |
------------------------------------------





<stable_baselines3.ppo.ppo.PPO at 0x2d4fb413e48>
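If the model is only needed for inference, PPO.load can also be called without an environment; env is required mainly to continue training, as above. A brief sketch (inference_model is an illustrative name):

inference_model = PPO.load(PPO_Path)             # no env attached
# action, _states = inference_model.predict(obs)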

Evaluating

evaluate_policy(model,env,n_eval_episodes=10,render=True)
C:\ProgramData\Anaconda3\envs\uav\lib\site-packages\stable_baselines3\common\evaluation.py:69: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.
  UserWarning,





(200.0, 0.0)
env.close()
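The UserWarning above is about missing episode statistics: wrapping the evaluation environment in a Monitor lets evaluate_policy report the true episode lengths and rewards. A minimal sketch (eval_env is created fresh just for this example):

from stable_baselines3.common.monitor import Monitor

eval_env = Monitor(gym.make(environment_name))
evaluate_policy(model, eval_env, n_eval_episodes=10, render=False)
eval_env.close()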

Test model

episodes = 5
for episode in range(1,episodes+1):
    obs = env.reset()
    done = False
    score = 0
    while not done:
        env.render()
        action,_ = model.predict(obs) # predict
        obs,reward,done,info = env.step(action)
        score += reward
    print('episode:{} score:{}'.format(episode,score))
# env.close()
episode:1 score:[200.]
episode:2 score:[200.]
episode:3 score:[200.]
episode:4 score:[200.]
episode:5 score:[200.]
env.close()
obs = env.reset()
model.predict(obs)  # returns the action and the next (hidden) state
(array([0], dtype=int64), None)
env.action_space.sample()
1
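By default model.predict samples an action from the policy distribution; passing deterministic=True always picks the highest-probability action, which is usually preferable at evaluation time. For example:

action, _states = model.predict(obs, deterministic=True)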

Viewing logs in TensorBoard

training_log_path = os.path.join(log_path, 'PPO_1')  # the suffix is the index of the training run
!tensorboard --logdir={training_log_path} # localhost:6006
2021-12-07 15:28:43.632786: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'cudart64_100.dll'; dlerror: cudart6