tf.stop_grandient 用法

tensorflow的特点是,只需要通过Tensor变量构建Graph,和相应的优化目标 loss(也可以看作Graph的一部分),当调用优化器去minimize loss时,优化器会根据loss中所涉及的变量,自动进行BP,对所有的相关变量进行参数更新。

tensorflow 多Agent 灵活保存、更新Graph的各部分参数(tf.variable_scope(), tf.get_collection(), tf.train.Saver()中,我们已经讨论过,当复杂任务需要控制部分网络在部分loss下进行更新的方法。当网络结构不太复杂时,还有一种方式也能达到类似的目的,这就是 tf.stop_grandient.

顾名思义,tf.stop_grandient的作用就是阻断当前节点的梯度传播,使得梯度从loss反向传播到该节点后就停止传播,不再对更之前的变量进行更新。最常见的例子是在DQN中:

...
with tf.variable_scope('q_target'):
            q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')    # shape=(None, )
            self.q_target = tf.stop_gradient(q_target)
...
with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))

其原理类似于复制相应Tensor的值,生成一个新的Tensor,新Tensor的值永远保持和原Tensor一致,但新Tensor不依赖于任何参数变量。
在DQN中,self.q_target 和 self.q_eval_wrt_a 的值都是网络生成的,但此处self.q_target的作用类似label,真正需要更新的是self.q_eval_wrt_a的值,tf.stop_grandient正好满足了我们的需求。

写作过程参考了 关于tf.stop_gradient的使用及理解,特此感谢。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值