tf.stop_grandient

tensorflow的特点是,只需要通过Tensor变量(可以看作是自己给出的输入数据)构建Graph,和相应的优化目标 loss(也可以看作Graph的一部分),当调用优化器去minimize loss时,优化器会根据loss中所涉及的变量,自动进行BP,对所有的相关变量进行参数更新。

简单的模型可以直接从输入X开始,一层层地设置变量和operation,最终得到输出\hat{Y},并和label真实值Y,一起计算出loss,然后调用优化器最小化Loss。

然而,复杂的模型,并不希望像上面一样直接对模型进行端到端的更新,可能涉及到多个优化目标loss,这时候就不能像上面一样简单的直接对Graph内的全体参数直接进行梯度下降更新,而是需要灵活控制各部分参数。

主要依赖的函数是tf.variable_scope() 和 tf.get_collection()。

import tensorflow as tf
n1 = 50
input_dim = 100
output_dim = 1
X = tf.placeholder(tf.float32, [None, input_dim], 'X')
Y1 = tf.placeholder(tf.float32, [None, output_dim], 'Y1')
Y2= tf.placeholder(tf.float32, [None, output_dim], 'Y2')

with tf.variable_scope('Agent'):
    with tf.variable_scope('layer1'):
        w1 = tf.get_variable('w1', [input_dim, n1], trainable=True)
        b1 = tf.get_variable('b1', [1, n1], trainable=True)
        s = tf.matmul(X, w1) + b1
    with tf.variable_scope('layer2_1'):
        w21 = tf.get_variable('w21', [n1, output_dim], trainable=True)
        b21 = tf.get_variable('b21', [1, output_dim], trainable=True)
        y1 = tf.matmul(s, w21) + b21
    with tf.variable_scope('layer2_2'):
        w22 = tf.get_variable('w22', [n1, output_dim], trainable=True)
        b22 = tf.get_variable('b22', [1, output_dim], trainable=True)
        y2 = tf.matmul(s, w22) + b22

如上所示,我们定义了一个图’Agent’,所有的变量都在‘Agent’内,‘Agent-ayer1’内为第一层变量,其输出s被两个互相独立的层’Agent-layer2_1’和‘Agent-layer2_2’共享。

使用tf.get_collection()把相应scope内的变量取出来,输出为这些变量构成的列表:

agent_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent')
layer1_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent/layer1')
layer21_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent/layer2_1')
layer22_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Agent/layer2_2')

有了上面的准备工作,我们就可以进行一些灵活的操作:选择其中的部分参数进行训练;对部分训练好的参数进行保存。

通过 var_list 选择其中的部分参数进行训练:

loss1 = tf.losses.mean_squared_error(labels=Y1, predictions=y1)
loss2 = tf.losses.mean_squared_error(labels=Y2, predictions=y2)

train1 = tf.train.AdamOptimizer(lr).minimize(loss1, var_list=layer1_params+layer21_params)
train2 = tf.train.AdamOptimizer(lr).minimize(loss2, var_list=layer1_params+layer22_params)

或者,使用tf.train.Saver()对相应部分参数进行保存和读取:

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver1 = tf.train.Saver(agent_params)
saver2 = tf.train.Saver(layer1_params)
saver3 = tf.train.Saver(layer21_params)
saver4 = tf.train.Saver(layer22_params)
save_path1 = saver1.save(sess,"./models/agent.ckpt")
save_path2 = saver2.save(sess,"./models/layer1.ckpt")
save_path3 = saver3.save(sess,"./models/layer21.ckpt")
save_path4 = saver4.save(sess,"./models/layer22.ckpt")
saver1.restore(sess,"./models/agent.ckpt")
saver2.restore(sess,"./models/layer1.ckpt")
saver3.restore(sess,"./models/layer21.ckpt")
saver4.restore(sess,"./models/layer22.ckpt")

tensorflow 多Agent 灵活保存、更新Graph的各部分参数(tf.variable_scope(), tf.get_collection(), tf.train.Saver())_南阁风起的博客-CSDN博客

当网络结构不太复杂时,还有一种方式也能达到类似的目的,这就是 tf.stop_grandient即阻断当前节点的梯度传播,使得梯度从loss反向传播到该节点后就停止传播,不再对更之前的变量进行更新。

下面代码就是DQN的常规写法了,在DQN中有两个两个网络,一个eval net,一个target net。对eval net的参数更新是通过MSE + GD来更新的,而MSE的计算将用到target net对下一状态的估值,通常的做法是对eval net设置一个placeholder,也即引入一个输入,用这个placeholder计算loss。

placeholder输入的本身就是计算好了的q_target,也就是说我们通过feed_dict,将对target net进行计算得到的一个q_target Tensor传入placeholder中,当做常量来对待,我们可以把一次计算(eval/run)看作是一次截图,得到当时各个op的值。这样的话,我们对于eval net中loss的反传就不会影响到target net了。

就是对self.q_target赋值,就不存在更新了

...
self.q_target = tf.placeholder(tf.float32, [None, self.n_actions], name='Q_target')  # for calculating loss
...
with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval))
with tf.variable_scope('train'):
            self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
...
...
with tf.variable_scope('q_target'):
            q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')    # shape=(None, )
            self.q_target = tf.stop_gradient(q_target)
...
with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))

        第二种方法中直接拿target net中的q_target这个op来计算eval net中的loss显然是不妥的,因为我们对loss进行反传时将会影响到target net,这不是我们想看到的结果。所以,这里引入stop_gradient来对从loss到target net的反传进行截断,换句话说,通过self.q_target = tf.stop_gradient(q_target),将原本为TensorFlow计算图中的一个op(节点)转为一个常量self.q_target,这时候对于loss的求导反传就不会传到target net去了。

tf.stop_grandient 用法_南阁风起的博客-CSDN博客 

关于tf.stop_gradient的使用及理解_微丶念(小矿工)的博客-CSDN博客_stop_gradient

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值