处理类别不平衡的损失函数系列总结

在图像分割领域,我们会常常遇到类别不平衡的问题。比如要分割的目标(前景)可能只占图像的一小部分,因此负样本的比重很大,导致训练结果用来做预测,网络倾向于将样本判断为负样本。这篇博客为大家带来一些用于处理类别不平衡的损失函数的原理讲解和代码实现。时间关系会不断更新,而不是一次性写完。

Weighted cross entropy

思路是想用一个系数描述样本在loss的重要性。对于小数目样本,我们加强它对loss的贡献,对于大数目的样本我们削减它对loss的贡献。是不是和adaboost很像呢?
先看公式

l o s s = − ∑ i w × y i × l o g ( l o g i t s i ) + ( 1 − y i ) × l o g ( 1 − l o g i t s i ) loss = -\sum\limits_ i w \times y_i \times log(logits_i) +(1-y_i)\times log(1-logits_i) loss=iw×yi×log(logitsi)+(1yi)×log(1logitsi)
这和二值交叉熵仅仅有一点变化,就是在正样本的判别上加了一个w系数。w是需要提供的,需要我们预先根据数据集计算。
假设一个分割数据集有20类,每类的样本数目为 n i 个 n_i个 ni i i i从1到20。那么有一种median blance的办法计算w。
求出这20个样本数目的中值,假设是 n x n_x nx。所有的n除以 n x n_x nx得到新的一组系数,这组系数取倒数就得到了对应类别的系数。
具体分析一下。对于样本数目多的类别,除以一个数字 n x n_x nx,按大小排序和之前的序列是一样的,仍然大于很多数字。这时我们取倒数,作为损失权重系数。那么样本数目多的类别权重系数小,对损失函数的贡献就小了;反过来,样本数目小的类别,得到的权重系数大,加强了对损失函数的贡献;而样本数目处于中间的那些类别,权重系数接近1,相当于没加强作用也没削弱其作用。(这是segnet提出的办法)。

numpy代码如下:

import numpy as np

def Wce(logits,label,weight):
    '''
    :param logits:  net's output, which has reshaped [batch size,num_class]
    :param label:   Ground Truth which is ont hot encoing and has typr format of [batch size, num_class]
    :param weight:  a vector that describes every catagory's coefficent whose shape is (num_class,)
    :return: a scalar 
    '''
    loss = np.dot(np.log2(logits)*label,np.expand_dims(weight,axis=1)) + \
           np.log2(logits) * (1-label)
    return loss.sum()

focal loss

focal loss的设计也是很巧妙的。通过对标准的交叉熵做改进,加入了描述样本难易分类的难易程度,并且相对放大对难分类样本的梯度,相对降低对易分类样本的梯度。
l o s s = − a y ( 1 − y ′ ) r × l o g ( y ′ ) − ( 1 − a ) ( 1 − y ) y ′ r × l o g ( 1 − y ′ ) loss = -ay(1-y')^r \times log(y')-(1-a)(1-y)y'^r\times log(1-y') loss=ay(1y)r×log(y)(1a)(1y)yr×log(1y)
其中 y y y是样本的label, y ′ y' y是样本的logits, a , r a,r ar是超参数。

  • a a a是样本平衡因子,在0-1之间
  • r r r的作用是相对放大难分类样本的梯度,相对降低易分类样本的梯度,为0时则是标准的交叉熵
    在Focal Loss中,它更关心难分类样本,不太关心易分类样本,比如:

若 gamma = 2,
对于正类样本来说,如果预测结果为0.97那么肯定是易分类的样本,所以就会很小;
对于正类样本来说,如果预测结果为0.3的肯定是难分类的样本,所以就会很大;
对于负类样本来说,如果预测结果为0.8那么肯定是难分类的样本,就会很大;
对于负类样本来说,如果预测结果为0.1那么肯定是易分类的样本,就会很小。
(此部分内容摘自focal loss

显然对于语义分割任务, a a a的重要性更大,因为类别不平衡才是我们语义分割关心的。设置 a > 0.5 a>0.5 a>0.5,那么就增大了对正样本损失函数值,提高了网络对正样本的重视度。

import numpy as np

def focal_loss(logits,label,a,r):
    '''
    :param logits: [batch size,num_classes] score value
    :param label: [batch size,num_classes] gt value
    :param a: generally be 0.5
    :param r: generally be 0.9
    :return: scalar loss value of a batch
    '''
    p_1 = - a*np.power(1-logits,r)*np.log2(logits)*label
    p_0 = - (1-a)*np.power(logits,r)*np.log2(1-logits)*(1-label)
    return (p_1 + p_0).sum()

dice loss

  • 14
    点赞
  • 71
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值