在tensorflow中,tf.where 和tf.greater的组合相当于是一个分段函数,其梯度传播主要取决于里面不同condition,以下是一个简单示例:当x1小于0时,y1对x1的梯度为0,梯度在这里停止传播;当x2大于0时,y2对于x2的梯度为1;
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 5 14:58:56 2018
@author: lunxi.yuan
"""
import tensorflow as tf
with tf.Graph().as_default() as graph:
x1 = tf.Variable(-1)
x2 = tf.Variable(3)
y1=tf.where(tf.greater(x1,0),x1,0)
y2 = tf.where(tf.greater(x2,0),x2,0)
grad_1 = tf.gradients(y1, x1)
grad_2 = tf.gradients(y2, x2)
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(grad_1))
print(sess.run(grad_2))