@tf.custom_gradient

@tf.custom_gradient

初衷

网上资料较少,而且官方文档比较ambigious(也许有误),花了比较久的时间搞懂这个修饰器,记此贴防止大家走弯路。

官方文档
参考文档

介绍

@tf.custom_gradient

装饰器允许控制对梯度的一连串操作,这样做的好处是对梯度操作提供一种更有效率更稳定方式。

考虑一种情况
在这里插入图片描述
由于数值不稳定性,x=100处的梯度( ▽ f = ∂ f ∂ x i ⃗ \bigtriangledown f=\frac{\partial f}{\partial x}\vec i f=xfi )由函数得到的值为 N a n Nan Nan

在这里插入图片描述

解决方法

使用@custom_gradient,梯度表达式可以被解析简化,以提供数值稳定性
在这里插入图片描述
可以推断@tf.custom_gradient的
  args为 x x x,
  returns为 y , ∂ y ∂ x y,\frac{\partial y}{\partial x} y,xy的函数形式
一方面调用y=log1exp(x),可以得到y=y
另一方面调用grady=gradient(y,x),可以得到grady= ∂ y ∂ x \frac{\partial y}{\partial x} xy

于是对于二阶导
只需要定义一阶导的嵌套形式,使用@custom_gradient修饰一阶导并使其返回y对x的一阶导以及y对x二阶导对应的函数

代码如下
@tf.custom_gradient
def log1pexp2(x):
    e = tf.exp(x)
    y = tf.math.log(1 + e)
    x_grad = 1 - 1 / (1 + e)
    def first_order_gradient(dy):
        @tf.custom_gradient
        def first_order_custom(unused_x):
            def second_order_gradient(ddy):
                # Let's define the second-order gradient to be (1 - e)
                return ddy * (1 - e) 
            return x_grad, second_order_gradient
        return dy * first_order_custom(x)
    return y, first_order_gradient

以上二阶导不是真实的二阶导(为了便于检测)

测试代码如下
import tensorflow as tf

@tf.custom_gradient
def log1pexp2(x):
    e = tf.exp(x)
    y = tf.math.log(1 + e)
    x_grad = 1 - 1 / (1 + e)
    def first_order_gradient(dy):
        @tf.custom_gradient
        def first_order_custom(unused_x):
            def second_order_gradient(ddy):
                # Let's define the second-order graidne to be (1 - e)
                return ddy * (1 - e) 
            return x_grad, second_order_gradient
        return dy * first_order_custom(x)
    return y, first_order_gradient

x1 = tf.constant(1.)
y1 = log1pexp2(x1)
dy1 = tf.gradients(y1, x1)
ddy1 = tf.gradients(dy1, x1)

x2 = tf.constant(100.)
y2 = log1pexp2(x2)
dy2 = tf.gradients(y2, x2)
ddy2 = tf.gradients(dy2, x2)

with tf.Session() as sess:
    print('x=1, dy1:', dy1[0].eval(session=sess))
    print('x=1, ddy1:', ddy1[0].eval(session=sess))
    print('x=100, dy2:', dy2[0].eval(session=sess))
    print('x=100, ddy2:', ddy2[0].eval(session=sess))
运行结果
x=1, dy1: 0.7310586
x=1, ddy1: -1.7182817
x=100, dy2: 1.0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值