The tf.gradients() function in TensorFlow has the following signature:
tf.gradients(ys,
             xs,
             grad_ys=None,
             name="gradients",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None,
             stop_gradients=None)
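This function computes the gradients of ys with respect to xs. It returns a list of tensors of length len(xs); each entry is the sum of dy/dx over every y in ys and has the same shape as the corresponding x.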
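For example:
When ys=[y1,y2] and xs=[x1,x2,x3], the gradients of ys with respect to xs are computed as:
    grads = [∂y1/∂x1 + ∂y2/∂x1, ∂y1/∂x2 + ∂y2/∂x2, ∂y1/∂x3 + ∂y2/∂x3]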
Verifying this with code:
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.get_variable)

# Four trainable vectors of shape [3]
x1 = tf.get_variable('w1', shape=[3])
x2 = tf.get_variable('w2', shape=[3])
x3 = tf.get_variable('w3', shape=[3])
x4 = tf.get_variable('w4', shape=[3])

y1 = x1 + x2 + x3
y2 = x3 + x4

# grads = tf.gradients([y1, y2], [x1, x2, x3, x4], grad_ys=[tf.convert_to_tensor([2.,2.,3.]), tf.convert_to_tensor([3.,2.,4.])])
grads = tf.gradients([y1, y2], [x1, x2, x3, x4])

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(grads))
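The output is a list of four arrays, one per x: [1., 1., 1.] for x1, x2 and x4, and [2., 2., 2.] for x3, because x3 feeds both y1 and y2.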
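Why does each result have three columns? Because, as mentioned above, each gradient has the same shape as its x, which here is [3].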
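As a cross-check that does not need TensorFlow, here is a minimal pure-Python sketch. It assumes that, when grad_ys is None, tf.gradients differentiates the sum of every element of every y, and it estimates that derivative with central finite differences. The helper names (ys, total, numeric_grads) and the input values are made up for illustration; this is not how TensorFlow computes gradients internally.

```python
def ys(x1, x2, x3, x4):
    """Same graph as above: y1 = x1 + x2 + x3 and y2 = x3 + x4, element-wise."""
    y1 = [a + b + c for a, b, c in zip(x1, x2, x3)]
    y2 = [c + d for c, d in zip(x3, x4)]
    return y1, y2


def total(xs):
    """Sum of every element of every y (the assumed scalar being differentiated)."""
    y1, y2 = ys(*xs)
    return sum(y1) + sum(y2)


def numeric_grads(xs, eps=1e-6):
    """Central finite differences of total() w.r.t. each element of each x."""
    grads = []
    for i, x in enumerate(xs):
        g = []
        for j in range(len(x)):
            plus = [list(v) for v in xs]
            minus = [list(v) for v in xs]
            plus[i][j] += eps
            minus[i][j] -= eps
            g.append((total(plus) - total(minus)) / (2 * eps))
        grads.append(g)
    return grads


# Arbitrary input values; the graph is linear, so the gradients do not depend on them.
xs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]
grads = [[round(v, 3) for v in g] for g in numeric_grads(xs)]
print(grads)
```

The entry for x3 comes out as 2 in every position because x3 contributes to both y1 and y2, while the other three inputs each contribute to only one y.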
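grad_ys holds the gradients with respect to ys, one tensor per y. If it is not None, the chain rule is used to compute the gradients with respect to xs, so each term is weighted by the matching entry of grad_ys:
    grads = [grad_y1 * ∂y1/∂x1 + grad_y2 * ∂y2/∂x1, grad_y1 * ∂y1/∂x2 + grad_y2 * ∂y2/∂x2, grad_y1 * ∂y1/∂x3 + grad_y2 * ∂y2/∂x3]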
Verifying this with code:
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.get_variable)

x1 = tf.get_variable('w1', shape=[3])
x2 = tf.get_variable('w2', shape=[3])
x3 = tf.get_variable('w3', shape=[3])
x4 = tf.get_variable('w4', shape=[3])

y1 = x1 + x2 + x3
y2 = x3 + x4

# One grad_ys entry per y, each with the same shape as that y
grads = tf.gradients([y1, y2], [x1, x2, x3, x4],
                     grad_ys=[tf.convert_to_tensor([2., 2., 3.]),
                              tf.convert_to_tensor([3., 2., 4.])])

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(grads))
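The result is [2., 2., 3.] for x1 and x2, [5., 4., 7.] for x3 (the sum of both grad_ys entries) and [3., 2., 4.] for x4.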
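The weighting by grad_ys can be reproduced by hand for this addition-only graph. The sketch below is pure Python and assumes only that dy/dx is 1 wherever an x feeds a y and 0 otherwise, which holds for element-wise addition. The feeds dictionary and the weighted_grads helper are hypothetical names for illustration, not TensorFlow APIs.

```python
# Which xs feed which y in the example graph: y1 = x1 + x2 + x3, y2 = x3 + x4.
feeds = {"y1": ["x1", "x2", "x3"], "y2": ["x3", "x4"]}

# Same upstream gradients as the grad_ys argument in the TensorFlow example.
grad_ys = {"y1": [2.0, 2.0, 3.0], "y2": [3.0, 2.0, 4.0]}


def weighted_grads(feeds, grad_ys, x_names):
    """For each x, add up the grad_ys of every y that x feeds into.

    Valid only for element-wise addition, where dy/dx is the identity.
    """
    size = len(next(iter(grad_ys.values())))
    grads = {}
    for x in x_names:
        acc = [0.0] * size
        for y, inputs in feeds.items():
            if x in inputs:
                acc = [a + g for a, g in zip(acc, grad_ys[y])]
        grads[x] = acc
    return grads


grads = weighted_grads(feeds, grad_ys, ["x1", "x2", "x3", "x4"])
print(grads)
```

For x3 the two upstream gradients are added, giving [5.0, 4.0, 7.0], while x1 and x2 simply receive the gradient of y1 and x4 receives the gradient of y2.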
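So far so good: both examples above take gradients through tensor addition, which is the simple case. But when I tried to verify the gradient of a tensor multiplication, I ran into a problem. When applying the chain rule, I had always assumed that the "*" in the formula meant element-wise multiplication (which is indeed what happens when differentiating tensor addition).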
That is not the case when differentiating a matrix multiplication, however. Here is my first attempt at verifying it in code:
import tensorflow as tf  # TensorFlow 1.x API

x1 = tf.Variable([[3., 4.]])         # shape [1, 2]
y1 = tf.matmul(x1, [[2.], [1.]])     # shape [1, 1]

# grad_ys has shape [1, 2] but y1 has shape [1, 1], so this attempt fails
grads = tf.gradients(y1, x1, grad_ys=tf.convert_to_tensor([[1., 2.]]))

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    re = sess.run(grads)
    print(re)
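The gradient of y1 with respect to x1 is [[2., 1.]] (the transpose of the constant matrix), which has the same shape as [[1., 2.]], so in principle the two could be multiplied element-wise. Yet the program raised an error.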
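Changing the argument to grad_ys=tf.convert_to_tensor([[2.]]) solved the problem: grad_ys must have the same shape as ys, because it is propagated back through the operation rather than multiplied element-wise with the finished gradient.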
import tensorflow as tf  # TensorFlow 1.x API

x1 = tf.Variable([[3., 4.]])         # shape [1, 2]
y1 = tf.matmul(x1, [[2.], [1.]])     # shape [1, 1]

# grad_ys now matches the shape of y1: [1, 1]
grads = tf.gradients(y1, x1, grad_ys=tf.convert_to_tensor([[2.]]))

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    re = sess.run(grads)
    print(re)
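This gives the correct result, [[4., 2.]], which is grad_ys multiplied by the transpose of the constant matrix: [[2.]] x [[2., 1.]].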
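The shape requirement becomes easier to see when the backward step is written out by hand. The following is a minimal pure-Python sketch, assuming the standard rule that for y = x · W the gradient with respect to x is grad_y · Wᵀ. The matmul and transpose helpers are written here for illustration and are not TensorFlow functions.

```python
def matmul(a, b):
    """Plain matrix product of two nested lists; raises on a shape mismatch."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions differ: %d vs %d" % (len(a[0]), len(b)))
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a nested list."""
    return [list(row) for row in zip(*m)]


w = [[2.0], [1.0]]      # constant matrix from the example, shape [2, 1]
x1 = [[3.0, 4.0]]       # shape [1, 2]
y1 = matmul(x1, w)      # forward pass, shape [1, 1]

# Backward pass with grad_y of shape [1, 1], the same shape as y1.
grad_x1 = matmul([[2.0]], transpose(w))

# A grad_y of shape [1, 2] cannot be multiplied by the [1, 2] transpose of w.
try:
    matmul([[1.0, 2.0]], transpose(w))
    mismatch_failed = False
except ValueError:
    mismatch_failed = True

print(y1, grad_x1, mismatch_failed)
```

With grad_y shaped like y1, the product has shape [1, 2], the same as x1. With grad_y shaped like x1, the inner dimensions no longer agree, which mirrors the failure of the first attempt.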
Reference [1]: https://blog.csdn.net/weixin_35364049/article/details/82220857