def focal_loss(logits,onehot_labels, gamma=2.0, alpha=4.0):
"""
focal loss for multi-classification
FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)
Notice: logits is probability after softmax
gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper
d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x)
Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017).
Focal Loss for Dense Object Detection. ICCV 2017.
https://arxiv.org/abs/1708.02002
:param onehot_labels: ground truth one-hot labels, shape of [batch_size, num_cls]
:param logits: model's output, shape of [batch_size, num_cls]
:param gamma:
:param alpha: positive samples are weighted by alpha, negative samples by (1-alpha); the higher alpha is, the better positive samples are trained and the worse negative samples are trained
:return: shape of [batch_size]
"""
epsilon = 1.e-9
logits = tf.convert_to_tensor(logits, tf.float32)
model_out = tf.add(logits, epsilo