Drawing on several reference articles, this post implements the DQN algorithm with TensorFlow 2.x for OpenAI Gym environments. A few arguably unnecessary(?) extras are included: automatic video recording, training logs for data visualization with TensorBoard, and saving/loading of both the replay memory and the DQN weights. The code targets the CartPole environment and can be adapted to MountainCar with only minor changes. The code download and references are at the end of this article.
Contents
- Code framework
- Main loop
- Single episode
- Custom reward
- Automatic video recording
- The DQN algorithm
- Neural network
- DQN
- Results
- TensorBoard visualization
- On the reward
- Code download
- References
The CartPole environment
Code framework
Main loop
-------------Initialization---------------
- Create a Gym environment instance
- Set the DQN parameters
- Create a DQN instance
- Load the DQN memory (optional)
- Load the DQN weights (optional)
- Create a TensorFlow summary writer
- Set the training parameters
--------------Training----------------
- Loop N times
| Play one episode; collect its reward and losses
| Log the data to the summary
| Print the data to the console
- End loop
--------------Wrap-up----------------
- Save the DQN memory
- Save the DQN weights
- Record a video
- Close the environment
Code:

```python
from cart_pole import MyModel
from cart_pole import DQN
from cart_pole import play_game
from cart_pole import make_video
import numpy as np
import tensorflow as tf
import gym
import os
import datetime
from statistics import mean
from gym import wrappers


def main():
    #################### Initialization ####################
    env = gym.make('CartPole-v0')
    gamma = 0.9
    num_states = len(env.observation_space.sample())
    num_actions = env.action_space.n
    hidden_units = [20, 20]
    max_experiences = 2000
    min_experiences = 1000
    batch_size = 32
    lr = 0.01
    e_greedy = 0.9
    e_greedy_increment = 1.001
    replace_target_iter = 50
    DQN_ = DQN(num_states, num_actions, hidden_units, gamma, max_experiences,
               min_experiences, batch_size, lr, e_greedy, replace_target_iter,
               e_greedy_increment)
    DQN_.loadModel()
    DQN_.loadMemory()
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = 'logs/dqn/' + current_time
    summary_writer = tf.summary.create_file_writer(log_dir)
    N = 10000  # total number of training episodes
    display_interval = 100
    total_rewards = np.empty(N)
    #################### Main loop ####################
    for n in range(N):
        total_reward, losses = play_game(env, DQN_)
        total_rewards[n] = total_reward
        avg_rewards = total_rewards[max(0, n - display_interval):(n + 1)].mean()
        with summary_writer.as_default():
            tf.summary.scalar('episode reward', total_reward, step=n)
            tf.summary.scalar('running avg reward(100)', avg_rewards, step=n)
            tf.summary.scalar('average loss', losses, step=n)
        if n % display_interval == 0:
            print("episode:", n, "eps:", DQN_.epsilon,
                  "avg reward (last ", display_interval, "):", avg_rewards,
                  "episode loss: ", losses)
    #################### Wrap-up ####################
```
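The `max_experiences` and `min_experiences` parameters above bound the replay memory: the `DQN` class (defined elsewhere in `cart_pole`) presumably only starts sampling training batches once the buffer holds `min_experiences` transitions, and evicts the oldest ones beyond `max_experiences`. A minimal sketch of such a buffer, with illustrative names rather than the author's actual implementation:

```python
import random
from collections import deque


class ReplayBuffer:
    """FIFO experience replay: the oldest transitions are discarded
    automatically once the buffer exceeds max_experiences."""

    def __init__(self, max_experiences=2000, min_experiences=1000):
        self.buffer = deque(maxlen=max_experiences)
        self.min_experiences = min_experiences

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def ready(self):
        # Only sample once enough experience has accumulated.
        return len(self.buffer) >= self.min_experiences

    def sample(self, batch_size=32):
        # Uniform random minibatch, as in vanilla DQN.
        return random.sample(self.buffer, batch_size)
```

The `deque(maxlen=...)` handles eviction for free, which is why a deque is a common choice here over a plain list.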
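The values `e_greedy = 0.9` and `e_greedy_increment = 1.001` suggest that here epsilon is the probability of acting *greedily*, grown multiplicatively toward a cap so that exploration decays over training (the `DQN` class is defined elsewhere, so this is an assumed scheme, not the author's exact code):

```python
import random


def choose_action(q_values, epsilon, num_actions, rng=random):
    """Epsilon-greedy selection where `epsilon` is the probability of
    exploiting (taking the argmax action); otherwise act randomly."""
    if rng.random() < epsilon:
        return int(max(range(num_actions), key=lambda a: q_values[a]))
    return rng.randrange(num_actions)


def update_epsilon(epsilon, increment=1.001, epsilon_max=0.99):
    """Grow epsilon multiplicatively toward a cap after each step,
    so the agent exploits more as training progresses."""
    return min(epsilon * increment, epsilon_max)
```

Note that `epsilon_max` is an assumed cap for illustration; with `increment = 1.001`, epsilon rises from 0.9 slowly over roughly a hundred updates.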
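Finally, `replace_target_iter = 50` indicates that the target network's weights are copied from the online network every 50 training steps, which is the standard DQN trick for stabilizing the bootstrapped Q-targets. A toy sketch of the mechanism (plain Python lists stand in for the two Keras models):

```python
class TargetSync:
    """Copy online-network parameters into the target network
    every `replace_target_iter` training steps."""

    def __init__(self, replace_target_iter=50):
        self.replace_target_iter = replace_target_iter
        self.step = 0
        self.online_params = [0.0]
        self.target_params = [0.0]

    def train_step(self, delta=0.1):
        self.online_params[0] += delta  # stand-in for a gradient update
        self.step += 1
        if self.step % self.replace_target_iter == 0:
            # Hard update: target lags the online network by up to
            # replace_target_iter steps, keeping Q-targets stable.
            self.target_params = list(self.online_params)
```

Between copies the target network stays frozen, so the regression targets do not chase a moving network on every step.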