Highway-env和stable_baselines3的RL环境配置注意事项

本文讨论conda环境下的RL实验相关配置。

首先是在conda中创建环境。这里需要注意,因为stable_baselines3需要python3.8。如果采用了3.7及以下版本,则会在后期报错:

ValueError: unsupported pickle protocol: 5

无法打开pickle文件。

(1)首先创建虚拟环境。

conda create -n XXX python=3.8

python>=3.8会自带pickle。

(2)激活并进入该虚拟环境。

activate XXX

(3)安装相关环境依赖包:gym,pygame,stable_baselines3,pytorch。

关于pytorch,需要根据硬件条件进行配置。最好pytorch>=1.13.0。

# CUDA 11.6
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia
# CUDA 11.7
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
# CPU Only
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 cpuonly -c pytorch

然后安装其它包: 

pip install gym
pip install pygame
pip install stable-baselines3
pip install --user git+https://github.com/eleurent/highway-env
or 
pip install highway-env

这里可能会存在找不到stable_baselines3模块的情况,一个简单暴力的解决方案是直接把stable_baselines3模块放在文件的同一目录下。(如果没有出现这个情况,请忽略) 

后期可能需要用到moviepy来做记录,可以也提前安装好。

pip install moviepy

 之后可能会出现import失败报错,尝试重装pillow。这一部分稍微有些玄学,得多试,下面是三种方法:

pip uninstall Pillow
pip install Pillow
python -m pip uninstall Pillow
python -m pip install Pillow
conda install Pillow

最后运行相关的样例代码,如下:

import gymnasium as gym
import highway_env
from stable_baselines3 import DQN


env = gym.make("highway-fast-v0")
model = DQN('MlpPolicy', env,
              policy_kwargs=dict(net_arch=[256, 256]),
              learning_rate=5e-4,
              buffer_size=15000,
              learning_starts=200,
              batch_size=32,
              gamma=0.8,
              train_freq=1,
              gradient_steps=1,
              target_update_interval=50,
              verbose=1,
              tensorboard_log="highway_dqn/")
model.learn(int(2e4))
model.save("highway_dqn/model")

# Load and test saved model
env = gym.make("highway-fast-v0", render_mode='rgb_array')
model = DQN.load("highway_dqn/model")
while True:
    done = truncated = False
    obs, info = env.reset()
    while not (done or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        env.render()

如果是CPU上跑上述代码,可能会花费20分钟左右,算是一个可以接受的时间长度了。下面是个人电脑配置,仅供参考。

也可以考虑改变model.learn()里面的参数大小,进行相应的观察。

可以得到最终结果(动图)。

Traceback (most recent call last): File "op2_walk_improved.py", line 5, in <module> from stable_baselines3 import PPO File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\stable_baselines3\__init__.py", line 3, in <module> from stable_baselines3.a2c import A2C File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\stable_baselines3\a2c\__init__.py", line 1, in <module> from stable_baselines3.a2c.a2c import A2C File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\stable_baselines3\a2c\a2c.py", line 4, in <module> from gymnasium import spaces File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\gymnasium\__init__.py", line 12, in <module> from gymnasium.envs.registration import ( File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\gymnasium\envs\__init__.py", line 382, in <module> load_plugin_envs() File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\gymnasium\envs\registration.py", line 600, in load_plugin_envs fn = plugin.load() File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\importlib_metadata\__init__.py", line 209, in load module = import_module(match.group(&#39;module&#39;)) File "C:\Users\86151\anaconda3\envs\webots\lib\importlib\__init__.py", line 127, in import_module return _bootstrap._gcd_import(name[level:], package, level) File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\highway_env\__init__.py", line 19, in <module> from highway_env.envs.common.abstract import MultiAgentWrapper File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\highway_env\envs\__init__.py", line 1, in <module> from highway_env.envs.highway_env import * File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\highway_env\envs\highway_env.py", line 6, in <module> from highway_env.envs.common.abstract import AbstractEnv File "C:\Users\86151\anaconda3\envs\webots\lib\site-packages\highway_env\envs\common\abstract.py", line 11, in <module>
最新发布
03-18
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值