https://blog.csdn.net/LoseInVain/article/details/83108001
https://github.com/tensorflow/tensorflow/blob/7dd20b844ced19610f8fa67be61d93948563ac43/tensorflow/python/ops/custom_gradient.py
输入
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
@tf.custom_gradient
def DoublySign(x):
def grad(dy):
'''
dy 是从反向而言的上一层的梯度
'''
cond = (x >= -1) & (x <= 1)
zeros = tf.zeros_like(dy)
return tf.where(cond,dy,zeros)
# 这里相当于是在自己手动计算梯度,如果在-1和1之间,将sign的函数梯度修改为1,根据链式法则,梯度为dy*1=dy
# 而其他情况下,梯度为dy*0=0
# tf.where 和 tf.cond 的区别 https://blog.csdn.net/xiadimichen14908/article/details/83592282
return tf.sign(x),grad
x = tf.constant(np.linspace(-2,2,100))
y = DoublySign(x)
grad = tf.gradients(y,x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(x.eval())
print(y.eval())
print(sess.run(grad))
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
ax = plt.gca() # get current axis 获得坐标轴对象
plt.xlabel('x&