tensorflow中gradients基本用法

import tensorflow as tf
"""
在反向传播过程中,神经网络需要对 loss 对应的学习参数求偏导(也叫梯度),
该值用来乘以学习率然后更新学习参数使用的。
它是通过 tf.gradients 函数来实现的 
"""
weight = tf.Variable([[1, 2]])
y = tf.matmul(weight, [[9], [10]])
"""
第一个参数:求导公式的结果
第二个参数:要求的偏导的参数
"""
grads = tf.gradients(y, weight)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    grad_val = sess.run(grads)
    print(grad_val)   # [array([[ 9, 10]], dtype=int32)]


# 对多个公式的多个变量求偏导
tf.reset_default_graph()
weight1 = tf.get_variable('weight1', shape=[2])
weight2 = tf.get_variable('weight2', shape=[2])
weight3 = tf.get_variable('weight3', shape=[2])
weight4 = tf.get_variable('weight4', shape=[2])

y1 = weight1 + weight2 + weight3
y2 = weight3 + weight4

# grad_ys 公式的结果
gradients = tf.gradients([y1, y2], [weight1, weight2, weight3, weight4],
                         grad_ys=[tf.convert_to_tensor([1., 2.]), tf.convert_to_tensor([3., 4.])])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gradients))


"""
梯度停止的实现
对于反向传播过程中某种特殊情况需要停止梯度的运算时,
TensorFlow 中提供了 tf.stop_gradients 函数,
被它定义过的节点将没有梯度运算功能
"""

a = weight1 + weight2
a_stopped = tf.stop_gradient(a)
y3 = a_stopped + weight3

gradients1 = tf.gradients(y3, [weight1, weight2, weight3], grad_ys=[tf.convert_to_tensor([1., 2.])])
gradients2 = tf.gradients(y3, [weight3], grad_ys=[tf.convert_to_tensor([1., 2.])])
print(gradients1)  # [None, None, < tf.Tensor 'gradients_1/grad_ys_0:0' shape = (2,) dtype = float32 >]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    '''
    下面代码会报错
    因为weight1、weight2 的梯度被停止,程序试图去求一个None的梯度,所以报错
    注释掉求 gradients2 就又正确了
    '''
    # print(sess.run(gradients1))
    print(sess.run(gradients2))


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值