Previous chapter: creating a custom environment with gym: https://blog.csdn.net/WHTU_JZZ/article/details/126955600?spm=1001.2014.3001.5502
1. Designing the Reward Function
Our task is to find a path through the 5×5 grid below (cells r0 through r24) that minimizes the total resistance from r0 to r24.
The resistance of each cell in the grid is randomly generated with numpy:
p = np.random.uniform(0, 1, size=(5, 5))
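A runnable version of this snippet, with the import and an optional seed so repeated experiments reuse the same grid (the seed value here is only illustrative):

import numpy as np

np.random.seed(0)  # optional: fix the seed so every run produces the same grid
p = np.random.uniform(0, 1, size=(5, 5))  # one resistance value per cell, r0..r24
print(p)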
We therefore design the reward function as:
(1) when the agent has not yet reached the goal, the reward is the negative of the corresponding cell value;
(2) when the agent reaches the goal, the reward is 100.
# The episode ends when the agent's squared distance to the goal is zero,
# i.e. the agent sits exactly on the goal cell.
if (self.current_state[1] - self.goal[1]) ** 2 + (self.current_state[0] - self.goal[0]) ** 2 == 0:
    done = True
else:
    done = False

if done:
    reward = 100  # bonus for reaching the goal
else:
    reward = EachReward(current_state=self.current_state, rows=self.rows, cols=self.cols)
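EachReward itself comes from the environment code in the previous chapter; a minimal sketch of what it is assumed to do, treating current_state as a (row, col) pair indexing into the resistance grid p (the signature here is an assumption, not the author's exact code):

import numpy as np

p = np.random.uniform(0, 1, size=(5, 5))  # the resistance grid from above

def EachReward(current_state, rows, cols):
    # Assumed behavior: look up the cell the agent occupies and return
    # the negative of its resistance, so lower-resistance paths score higher.
    row, col = current_state
    return -p[row, col]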
2. Training and Saving the Model
Since the action space is discrete, we choose the DQN algorithm. Training and saving the model are built on PyTorch and stable-baselines3.
from stable_baselines3 import DQN
import os

from envs.my_env import PathPlanning
from envs.build_problems import p

models_dir = "models/DQN"
logdir = "logs"

# Create the output folders if they do not exist
if not os.path.exists(models_dir):
    os.makedirs(models_dir)
if not os.path.exists(logdir):
    os.makedirs(logdir)

env = PathPlanning()
env.reset()

model = DQN("MlpPolicy", env, verbose=1, tensorboard_log=logdir)  # the DQN agent

episodes = 15  # train for 15 rounds
for i in range(episodes):
    env.reset()
    # each round runs the agent for 10000 timesteps without resetting the step counter
    model.learn(total_timesteps=10000, reset_num_timesteps=False, tb_log_name="DQN")
    model.save(f"{models_dir}/{10000 * i}")  # periodic checkpoint

print(p)
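stable-baselines3's DQN also accepts the usual hyperparameters (learning rate, replay buffer size, exploration schedule). A sketch of a more explicit constructor call; the values below are purely illustrative, not tuned for this task:

model = DQN(
    "MlpPolicy",
    env,
    learning_rate=1e-4,          # illustrative values, not tuned
    buffer_size=50_000,          # replay buffer size
    exploration_fraction=0.3,    # fraction of training spent annealing epsilon
    exploration_final_eps=0.05,  # final epsilon for epsilon-greedy exploration
    verbose=1,
    tensorboard_log=logdir,
)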
During training, open a command line in the project folder and run tensorboard --logdir logs to monitor progress:
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.10.0 at http://localhost:6006/ (Press CTRL+C to quit)
Copy the URL shown there into a browser. The plots below show the run after training has finished:
The first plot shows the number of timesteps per episode, and the second shows the per-episode reward. After about 500k training steps, the reward stabilizes at a high level and the path length stabilizes at 8 steps.
3. Loading the Model
from stable_baselines3 import DQN

from envs.my_env import PathPlanning
from envs.my_env import p

env = PathPlanning()
env.reset()

models_dir = "models/DQN"
models_path = f"{models_dir}/9990000.zip"

model = DQN.load(models_path, env=env)

episodes = 10
print(p)
for episode in range(episodes):
    obs = env.reset()
    done = False
    rewards = 0
    while not done:
        # predict() returns (action, next hidden state); the state is unused for DQN
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        rewards += reward
        print("observation", obs)
    print("reward", rewards)
Run output:
[[0.45284899 0.62989762 0.32944424 0.3818241 0.9641225 ]
[0.44854522 0.57372626 0.75464161 0.83378715 0.91354881]
[0.70371768 0.26211896 0.517719 0.25071574 0.06647013]
[0.51405917 0.62054058 0.17492277 0.47799731 0.94916258]
[0.13165789 0.88561359 0.8383077 0.94061318 0.02174876]]