【FlappyBird小游戏】编写AI逻辑(五)——搭建计算图和学习机制

本文隶属于一个完整小项目,建议读者按照顺序阅读。

本文仅仅展示最关键的代码部分,并不会列举所有代码细节,相信具备RL基础的同学理解起来没有困难。

全部的AI代码可以在【Python小游戏】用AI玩Python小游戏FlappyBird【源码】中找到开源地址。

如果本文对您有帮助,欢迎点赞支持!


文章目录

一、定义计算图

1、定义评估网络的输入输出

2、定义目标网络的输入输出

3、定义两个网络的参数更新操作

二、定义学习机制

1、从经验重放池中进行批采样

2、根据批采样数据使用单步Q-learning公式计算目标Q值

3、将目标Q值和状态等输入评估网络,训练更新评估网络和目标网络


一、定义计算图

1、定义评估网络的输入输出

我们的评估网络输入是某个时刻的状态,输出是该状态下可以选择的每个动作的Q-eval。

我们的评估网络损失是该时刻下的目标Q值Q_target和当前Q值Q_eval的均方误差

其中当前Q值Q_eval就是评估网络的输出,目标Q值需要根据Q-Learning机制求得,所以这里先定义为计算图的输入:

# 定义目标网络的输入输出
self.s = tf.placeholder(dtype="float", shape=[None, 80, 80, 4], name='s') # 网络的状态输入
self.q_target = tf.placeholder(dtype="float", shape=[None,self.n_actions], name='q_target')

我们的评估网络定义如下:

        # 定义评估网络的输入输出
        with tf.variable_scope('eval_net'):
            with tf.variable_scope('output'):
                self.q_eval = self._define_cnn_net(self.s,c_names=['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES])
            with tf.variable_scope('loss'):
                self.cost = tf.reduce_mean(tf.squared_difference(self.q_target,self.q_eval))# 使用预测奖励值与当前位置的奖励平方差来获得损失值
                tf.summary.scalar("loss", self.cost)  # 使用TensorBoard监测该变量
            with tf.variable_scope('train'):
                self.trainStep = tf.train.AdamOptimizer(self.learn_rate).minimize(self.cost)

2、定义目标网络的输入输出

因为我们的nature dqn算法 使用的是双网络机制,该机制降低了数据之间的关联,可以使算法对数据的Q学习更加健壮,所以我们需要定义一个目标网络.

该网络和评估网络结构一样,但是参数不需要反向传播更新,所以不需要定义损失值和训练操作:

我们的目标网络定义如下:

        # 定义目标网络的输入输出
        with tf.variable_scope('target_net'):
            with tf.variable_scope('output'):
                self.q_next = self._define_cnn_net(self.s, c_names=['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES])

3、定义两个网络的参数更新操作

因为目标网络的参数是定期从当前网络中复制 而来,所以我们需要继续在计算图中定义参数更新操作:

t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

二、定义学习机制

算法的学习机制基本就是三步操作:

1、从经验重放池中进行批采样

2、根据批采样数据使用单步Q-learning公式计算目标Q值

3、将目标Q值和状态等输入评估网络,训练更新评估网络和目标网络

1、从经验重放池中进行批采样

批采样代码和我们的经验重播池定义紧密相关,这里我们的批采样代码如下:

minibatch =self.memory.sample(self.batch_size) # 获得一个batch的图片信息
state_batch = [data[0] for data in minibatch] # 获得状态信息, [80, 80, 4]
action_batch = [np.argmax(data[1]) for data in minibatch]# 获得动作信息,即向上的索引值为[1, 0], 向下的为[0, 1]
reward_batch = [data[2] for data in minibatch]# 获取奖励信息
nextState_batch = [data[3] for data in minibatch]# 获得下一状态信息, [80, 80, 4]
terminal_batch = [data[4] for data in minibatch]# 获得是否合格, [80, 80, 4]

2、根据批采样数据使用单步Q-learning公式计算目标Q值

(1)使用目标网络获取后继状态的后继Q值:

 q_next_batch = self.q_next.eval(feed_dict={self.s: nextState_batch})

(2)根据单步Q-Learning公式计算批数据的目标Q值:

# 根据Q-Learning机制计算目标Q值
q_eval_batch = self.sess.run(self.q_eval, {self.s: state_batch})  # 使用评估网络获取当前状态的当前Q值
q_target_batch = q_eval_batch.copy()  # 目标Q值和当前Q值具有相同的矩阵结构,所以直接复制
for i in range(0, self.batch_size):
    terminal = terminal_batch[i]
    if terminal:
         q_target_batch[i, action_batch[i]] = reward_batch[i]  + self.gamma * np.max(q_next_batch[i])  # 目标Q值=当前奖励+折扣因子*后继Q值
    else:
         q_target_batch[i, action_batch[i]] = reward_batch[i]  # 目标Q值=当前奖励

3、将目标Q值和状态等输入评估网络,训练更新评估网络和目标网络

(1)训练更新评估网络:使用计算图直接执行反向传播和损失计算:

_, cost = self.sess.run([self.trainStep, self.cost],
                                    feed_dict={self.s: state_batch,
                                               self.q_target: q_target_batch})

(2)定期更新目标网络:使用计算图直接执行参数更新操作(硬更新机制):

 if self.learn_step_counter % self.replace_target_iter == 0:
       self.sess.run(self.replace_target_op)
       print('\ntarget_net的参数被更新\n')

 

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

魔法攻城狮MRL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值