Reposted from: http://blog.csdn.net/u012436149/article/details/53905797
gradient
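TensorFlow provides a function for computing gradients, tf.gradients(ys, xs). Note that every x in xs must be related to ys. For an unrelated x, tf.gradients returns None, and fetching that None with sess.run raises an error.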
The code below defines two variables, w1 and w2, but res depends only on w1.
    # wrong
    import tensorflow as tf

    w1 = tf.Variable([[1, 2]])
    w2 = tf.Variable([[3, 4]])

    res = tf.matmul(w1, [[2], [1]])

    # w2 does not contribute to res, so its entry in grads is None
    grads = tf.gradients(res, [w1, w2])

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        re = sess.run(grads)  # fails here: None cannot be fetched
        print(re)
Error message:

    TypeError: Fetch argument None has invalid type

Asking only for the gradient with respect to w1 works:

    # right
    import tensorflow as tf

    w1 = tf.Variable([[1, 2]])
    w2 = tf.Variable([[3, 4]])

    res = tf.matmul(w1, [[2], [1]])

    grads = tf.gradients(res, [w1])

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        re = sess.run(grads)
        print(re)
    # [array([[2, 1]], dtype=int32)]
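When the list of xs is built automatically, it may not be known in advance which entries are unrelated to ys. One possible workaround is to drop the None entries before fetching. The sketch below is only an illustration: `connected_pairs` is a hypothetical helper, not a TensorFlow function, and plain strings stand in for the tensors so that it runs without TensorFlow.

```python
def connected_pairs(grads, xs):
    """Keep only the (gradient, x) pairs whose gradient is not None."""
    return [(g, x) for g, x in zip(grads, xs) if g is not None]


# Stand-ins for the result of tf.gradients(res, [w1, w2]) in the example
# above: a gradient for w1 and None for the unrelated w2.
grads = ["grad_w1", None]
xs = ["w1", "w2"]

pairs = connected_pairs(grads, xs)

# Only these entries would be safe to pass to sess.run.
fetchable = [g for g, _ in pairs]

print(pairs)
print(fetchable)
```

With real tensors, passing `fetchable` to sess.run instead of the raw list means no None is ever fetched, so the TypeError above should not occur.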
tf.stop_gradient()
Blocks the gradient from flowing back through a node during backpropagation (BP).

    import tensorflow as tf

    w1 = tf.Variable(2.0)
    w2 = tf.Variable(2.0)

    a = tf.multiply(w1, 3.0)
    a_stoped = tf.stop_gradient(a)

    # b = w1 * 3.0 * w2
    b = tf.multiply(a_stoped, w2)
    gradients = tf.gradients(b, xs=[w1, w2])
    print(gradients)
    # Output:
    # [None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
As shown, once a node is stopped, the gradient at that node can no longer propagate any further back. Because the gradient of w1 can only arrive through node a, tf.gradients returns None for it.

    import tensorflow as tf

    a = tf.Variable(1.0)
    b = tf.Variable(1.0)

    c = tf.add(a, b)
    c_stoped = tf.stop_gradient(c)

    d = tf.add(a, b)

    e = tf.add(c_stoped, d)

    gradients = tf.gradients(e, xs=[a, b])

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        print(sess.run(gradients))
    # Output: [1.0, 1.0]
Although node c is stopped, a and b still receive a gradient through d, so gradient values are still returned.

    import tensorflow as tf

    w1 = tf.Variable(2.0)
    w2 = tf.Variable(2.0)
    a = tf.multiply(w1, 3.0)
    a_stoped = tf.stop_gradient(a)

    # b = w1 * 3.0 * w2
    b = tf.multiply(a_stoped, w2)

    opt = tf.train.GradientDescentOptimizer(0.1)

    gradients = tf.gradients(b, xs=tf.trainable_variables())

    # The next line raises an error because gradients[0] is None; comment it
    # out to run the rest. Everything else works normally, both the gradient
    # computation and the variable update. This design feels a little
    # unfortunate to me; it would be nicer if the gradient flowing through
    # were simply 0.
    tf.summary.histogram(gradients[0].name, gradients[0])

    train_op = opt.apply_gradients(zip(gradients, tf.trainable_variables()))

    print(gradients)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        print(sess.run(train_op))
        print(sess.run([w1, w2]))
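The two behaviours in this section (a stopped node passes nothing back, and a variable still gets a gradient when another path reaches it) can be reproduced with a toy reverse-mode differentiator. This is a minimal sketch in plain Python under simplifying assumptions (scalars only, gradients summed path by path); it is not how TensorFlow is implemented, and `Node`, `add`, `mul`, `stop_gradient` and `gradients` are hypothetical stand-ins written for illustration.

```python
class Node:
    """A scalar value in a toy computation graph."""

    def __init__(self, value, parents=()):
        self.value = value
        # Tuple of (parent_node, local_derivative) pairs.
        self.parents = parents


def add(x, y):
    return Node(x.value + y.value, ((x, 1.0), (y, 1.0)))


def mul(x, y):
    return Node(x.value * y.value, ((x, y.value), (y, x.value)))


def stop_gradient(x):
    # Same forward value, but no edge back to x, so nothing flows through.
    return Node(x.value)


def gradients(y, xs):
    """Return dy/dx for each x in xs, or None when no path reaches x."""
    grads = {}

    def backprop(node, upstream):
        # Contributions arriving over different paths are summed.
        grads[id(node)] = grads.get(id(node), 0.0) + upstream
        for parent, local in node.parents:
            backprop(parent, upstream * local)

    backprop(y, 1.0)
    return [grads.get(id(x)) for x in xs]


# First example: the only path to w1 is stopped.
w1, w2 = Node(2.0), Node(2.0)
b = mul(stop_gradient(mul(w1, Node(3.0))), w2)
grads_stopped = gradients(b, [w1, w2])
print(grads_stopped)  # [None, 6.0]

# Second example: c is stopped, but d still reaches a and b.
a0, b0 = Node(1.0), Node(1.0)
e = add(stop_gradient(add(a0, b0)), add(a0, b0))
grads_two_paths = gradients(e, [a0, b0])
print(grads_two_paths)  # [1.0, 1.0]

# Treating "no path" as a zero gradient, as wished for above.
grads_as_zero = [0.0 if g is None else g for g in grads_stopped]
print(grads_as_zero)  # [0.0, 6.0]
```

The last step replaces None with 0.0 after the fact. Newer TensorFlow releases also document an `unconnected_gradients` argument on tf.gradients for the same purpose; whether it is available depends on the installed version.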