tensorflow的特点是,只需要通过Tensor变量构建Graph,和相应的优化目标 loss(也可以看作Graph的一部分),当调用优化器去minimize loss时,优化器会根据loss中所涉及的变量,自动进行BP,对所有的相关变量进行参数更新。
在 tensorflow 多Agent 灵活保存、更新Graph的各部分参数(tf.variable_scope(), tf.get_collection(), tf.train.Saver()中,我们已经讨论过,当复杂任务需要控制部分网络在部分loss下进行更新的方法。当网络结构不太复杂时,还有一种方式也能达到类似的目的,这就是 tf.stop_grandient.
顾名思义,tf.stop_grandient的作用就是阻断当前节点的梯度传播,使得梯度从loss反向传播到该节点后就停止传播,不再对更之前的变量进行更新。最常见的例子是在DQN中:
...
with tf.variable_scope('q_target'):
q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_') # shape=(None, )
self.q_target = tf.stop_gradient(q_target)
...
with tf.variable_scope('loss'):
self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
其原理类似于复制相应Tensor的值,生成一个新的Tensor,新Tensor的值永远保持和原Tensor一致,但新Tensor不依赖于任何参数变量。
在DQN中,self.q_target 和 self.q_eval_wrt_a 的值都是网络生成的,但此处self.q_target的作用类似label,真正需要更新的是self.q_eval_wrt_a的值,tf.stop_grandient正好满足了我们的需求。
写作过程参考了 关于tf.stop_gradient的使用及理解,特此感谢。