Tensorflow自定义激活函数/函数/梯度
(对于Tensorflow 1.x)
最近刚做完一个paper,需要自定义激活函数,记录一下心得,顺便帮助下有需要的小伙伴。大刀阔斧,直接上解决方案:
1、对于分段(激活)函数,代码分开写
2、使用自带自定义梯度
详解
Tensorflow是自动求导(不懂就百度),因此我们不需要定义梯度,但大家可能会遇到和我一样的问题(在训练模型的时候loss爆炸),所以大家才会来查吧。
自定义激活函数/函数直接定义就可以,比如:
output = tf.exp(input)
output = tf.log(input)
但为什么有时候会梯度爆炸?
因为激活函数大多是参照relu进行修改,故大多是分段函数,分段函数在tensorflow中使用
tf.where(tf.greater(input, [0.0]),function1,function2)
funtion1计算大于0的数,function2计算小于等于0的数,但这就导致我构造的激活函数loss爆炸。原因不详,猜测是先计算所有输入都参与function1和function2的计算。
我使用了tensorflow定义swish的例子定义函数:
def _swish_shape(op):
return [op.inputs[0].shape]