最近在进行分类任务的时候,发现了数据存在类别不平衡问题。除了类别不平衡问题之外还有难学样本和易学样本之间的不平衡问题。因此考虑使用了focal loss。这里直接上代码:
def focal_loss(logits, labels, gamma):
'''
:param logits: [batch_size, n_class]
:param labels: [batch_size]
:return: -(1-y)^r * log(y)
'''
softmax = tf.reshape(tf.nn.softmax(logits)