# 数据不均衡 -- Focal loss TensorFlow 实现 (class imbalance: focal loss implemented in TensorFlow)
def focal_loss(prediction_tesnsor, target_tensor, weights=None, alpha=0.25, gamma=2):
    """Compute the sigmoid (multi-label) focal loss, summed over all entries.

    Per entry, where p = sigmoid(x) and z = target:

        FL = -alpha * (z - p)^gamma * log(p) - (1 - alpha) * p^gamma * log(1 - p)

    Entries with z > 0 contribute only the first (positive) term; entries
    with z <= 0 contribute only the second (negative) term.

    Args:
        prediction_tesnsor: Logits tensor (pre-sigmoid). The misspelled
            parameter name is kept so existing keyword callers keep working.
        target_tensor: Target tensor compared element-wise against the
            logits' shape; presumably 0/1 (one-/multi-hot) labels — TODO confirm.
        weights: NOTE(review): accepted but never applied; confirm whether
            callers expect per-entry weighting before relying on it.
        alpha: Weight of the positive term; the negative term uses 1 - alpha.
        gamma: Focusing exponent applied to the modulating factors.

    Returns:
        A scalar tensor: the focal loss summed over every element.
    """
    sigmoid_p = tf.nn.sigmoid(prediction_tesnsor)
    zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)

    # Mask of positive targets, shared by both branches below.
    is_positive = target_tensor > zeros
    # Positive entries: modulating base is (z - p); zero for negative entries
    # so the positive term vanishes there. With soft targets (z < p) this base
    # is negative, so a non-integer gamma would produce NaN.
    pos_p_sub = array_ops.where(is_positive, target_tensor - sigmoid_p, zeros)
    # Negative entries: modulating base is p; zero for positive entries so the
    # negative term vanishes there.
    neg_p_sub = array_ops.where(is_positive, zeros, sigmoid_p)

    # Clip inside the logs to avoid log(0) = -inf. (`tf.log` is the TF1 name;
    # TF2 exposes it as `tf.math.log`.)
    pos_loss = -alpha * (pos_p_sub ** gamma) * tf.log(
        tf.clip_by_value(sigmoid_p, 1e-8, 1.0))
    neg_loss = -(1 - alpha) * (neg_p_sub ** gamma) * tf.log(
        tf.clip_by_value(1 - sigmoid_p, 1e-8, 1.0))
    return tf.reduce_sum(pos_loss + neg_loss)