2021-06-02

# tensorflow自定义操作并附加自定义导数

## 说明

tensorflow最初给的例程中,虽然给了使用C++代码创建操作,并附加相应导数的能力,但是对于定义一个已有的复杂操作,并修改其导数的方法还不甚明晰。

本文结合@function.Defun() 和 @tf.RegisterGradient()以及g.gradient_override_map()来实现相关功能

## 示例

import tensorflow.compat.v1 as tf 
from tensorflow.python.framework import function

tf.compat.v1.disable_eager_execution() 

@function.Defun(tf.float32)
def customOp(x):
    return tf.identity(x * x * 10 - 2)

@tf.RegisterGradient("CustomOpGrad")
def _clip_grad(unused_op, grad):
  return grad * 0.1
 
x = tf.Variable([3.0], dtype=tf.float32)
 
g = tf.get_default_graph()
with g.gradient_override_map({customOp.name: "CustomOpGrad"}):
  y = customOp(x)

dydx = tf.gradients(y, x)

print("operation :", customOp.name)
print("x :", x)
print("y :", y)
print("dydx :", dydx)
 
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print("dydx:", sess.run(dydx)[0])

## 结果

operation : customOp_yk92PMBQ5p0
x : <tf.Variable 'Variable:0' shape=(1,) dtype=float32>
y : Tensor("customOp_yk92PMBQ5p0:0", dtype=float32)
dydx : [<tf.Tensor 'gradients/customOp_yk92PMBQ5p0_grad/mul:0' shape=(1,) dtype=float32>]
dydx: [0.1]

## 说明

1. @function.Defun 注册customOp这个op.

需要注意的是,其name属性并不是叫customOp

2. @tf.RegisterGradient 注册CustomOpGrad这个op的求梯度操作

需要注意的是,其name属性就叫CustomOpGrad

3. g.gradient_override_map 将 customOp的求梯度操作替换成了CustomOpGrad

我看有些文章,给的示例都是tf.identity的

# 一般的示例
with g.gradient_override_map({"Identity": "CustomOpGrad"}):
  y = tf.identity(x, name="Identity")

# 对此示例不成立
# 只会影响tf.identity
# 不会影响tf.square
with g.gradient_override_map({"Identity": "CustomOpGrad"}):
  y = tf.identity(tf.square(x), name="Identity")

所以g.gradient_override_map这里填入的是定义的操作名,而不是操作之后的张量名

y就是一个张量,名字可以配置成Identity,但是我们想要的功能并没有实现

## 参考

这个可以按关键字自己百度,我是散的拼成+试错的

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值