Focal Loss原理以及代码实现和验证(tensorflow2)

Focal Loss

本博客使用到的完整代码请移步至: 我的github https://github.com/qingyujean/Classification-on-imbalanced-data,求赞求星求鼓励~~~

1. Focal Loss论文解读

    原论文是解决目标检测任务中,前景(或目标)与背景像素点的在量上(1:1000)以及分类的难易程度上的极度不均衡,而导致的one-stage detectos的性能总是上不去的问题。
    由于样本的不均衡,会导致训练过程会被易分类的大量的背景点所主导,而focal loss就好比一个dynamically scaled cross entropy loss,对于易分类的样本,随着其分类正确的置信度的越大的话,其loss会在调制因子的作用下越来越小,逐渐接近0。而对于那些极少数且难分类的负例给予非常大的关注。而且原文中强调,focal loss的在其他样本不均衡的场景,也是有效的(通用性)。

1.1 CE loss

C E ( p , y ) = C E ( p t ) = − l o g ( p t ) ,      w h e r e    p t = { p i f    y = 1 1 − p o t h e r w i s e CE(p, y) = CE(p_t) = -log(p_t), \;\;where\;p_t=\begin{cases}p&if\;y=1\\1-p&otherwise\end{cases} CE(p,y)=CE(pt)=log(pt),wherept={p1pify=1otherwise
其中 y ∈ { − 1 , 1 } y\in\lbrace -1, 1\rbrace y{1,1}. even examples that are easily classifified ( p t ≫ . 5 p_t\gg.5 pt.5) incur a loss with non-trivial magnitude. When summed over a large number of easy examples, these small loss values can overwhelm the rare class. 存在的问题:即使是那些pt远大于0.5的样本,其loss也是不容忽视的,而如果把本来就量大的easy样本的loss加起来,这些小loss汇总起来也会压倒性的淹没少样本对loss的贡献。

1.2 balanced CE loss

α − b a l a n c e d    C E ( p t ) = − α t log ⁡ ( p t ) ,      w h e r e    p t = { p i f    y = 1 1 − p o t h e r w i s e \alpha-balanced\; CE(p_t) = -\alpha_t\log(p_t),\;\; where\;p_t=\begin{cases}p&if\;y=1\\1-p&otherwise\end{cases} αbalancedCE(pt)=αtlog(pt),wherept={p1pify=1otherwise

其中class为1时权重因子为 α ∈ [ 0 , 1 ] \alpha \in [0, 1] α[0,1],class为0时权重因子为 1 − α 1-\alpha 1α
    局限性: Easily classifified negatives comprise the majority of the loss and dominate the gradient. While α balances the importance of positive/negative examples, it does not differentiate between easy/hard examples. Instead, we propose to reshape the loss function to down-weight easy examples and thus focus training on hard negatives(balanced CE 只是从正负例的重要性程度上去影响样本对loss的贡献度,但没有考虑样本分类的难易程度,而focal loss旨在降低easy examples的权重而使得训练更加关注hard negatives (注意:个人理解上,我觉得是正负例中都会存在难分的样本,只是负例量更大,相对难分的样本数可能也会多些,这也是为什么引入γ后最后又还是保留了α的原因,是要让他们在一起共同作用于样本,最终效果才会更好)。

1.3 focal loss

F L ( p t ) = − ( 1 − p t ) γ log ⁡ ( p t ) ,      w h e r e    p t = { p i f    y = 1 1 − p o t h e r w i s e FL(p_t) = -(1-p_t)^\gamma\log(p_t),\;\; where\;p_t=\begin{cases}p&if\;y=1\\1-p&otherwise\end{cases} FL(pt)=(1pt)γlog(pt),wherept={p1pify=1otherwise
其实就是给CE增加了调制因子(modulating factor): ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ,其中 γ ≥ 0 \gamma\ge0 γ0

在这里插入图片描述
Focal Loss:给标准CE加上了一个调制因子: ( 1 − p t ) γ (1-p_t)^{\gamma} (1pt)γ,其中 γ > 0 \gamma>0 γ>0。参数 γ \gamma γ 减少了“易分样本”(well-classified examples, p t > 0.5 p_t>0.5 pt>0.5)的相对损失(relative loss),而给与困难的、误分类样本(hard,misclassified examples)更多的关注。

从图1可以注意到focal loss的两点性质:

  • (1) When an example is misclassifified and pt is small, the modulating factor is near 1 and the loss is unaffected. As pt 1, the factor goes to 0 and the loss for well-classifified examples is down-weighted. (当一个样本被误分类并且pt很小,则调制因子 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ→1,此时loss和CE的loss没什么太大区别(相比于log(pt)有轻微减小),而如果pt→1(pt>0.5时会正确分类)即对于那些well-classified的样本,其loss则会被显著降低,相应的也就突显出了hard examples的loss)
  • (2) The focusing parameter γ smoothly adjusts the rate at which easy examples are down weighted.(从图上可看出γ能平滑的调整easy examples权重下降的速率)

    当*γ=0时,等价于CE,而当γ增加,调制因子的影响也会增加。γ=2时,实验中效果最好。
i n    p r a c t i c e :      F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) ,      w h e r e    p t = { p i f    y = 1 1 − p o t h e r w i s e in\;practice:\;\;FL(p_t) = -\alpha_t(1-p_t)^\gamma\log(p_t),\;\; where\;p_t=\begin{cases}p&if\;y=1\\1-p&otherwise\end{cases} inpractice:FL(pt)=αt(1pt)γlog(pt),wherept={p1pify=1otherwise
原文中表示在选择α和γ时,最好两个在一起进行调参,因为实验表明岁的γ的增大,α会相应减小。因为focal loss相比CE,负例也贡献了不容忽视的loss,而focal loss中为easy negatives进行了down weight,所以相对而言正例的重要性也无须被“突出”出来了,所以 γ增大了而α减小。

个人理解上,所谓的样本不均衡主要体现在两点:

  1. 样本数量的不均衡(正例极少,负例非常多)
  2. 样本分类的难易程度的不均衡(本文中主要关注hard negative examples,我觉得主要是负例太多,所以hard negative examples也就相对多点,但并不是说正例没有hard examples,而且我觉得那些总是或者容易被误分类的样本就是hard examples)

2. tensorflow2验证focal loss

    本次实验的代码,数据集和模型以及绘图,参考了tensorflow2 官网教程的部分代码。数据集是信用卡欺诈相关的数据,正例表示是一个“欺诈交易”,负例表示一个正常交易。正例只占总数据量的0.17%,可见数据严重不均衡。

2.1 focal loss实现

    focal loss的实现,为了理解上更清晰点,严格按照论文中的公式实现。其中使用log(probs)函数时为了避免出现probs=0而导致log计算出错,使用了tf.clip_by_value对概率值进行了限制,使其必须>=一个很小的值,如下面的输入参数epsilon。

#公式:L(pt) = -αt(1-pt)^γ log(pt),
# pt=p and αt=α  when y=1 ,pt=1-p and αt=1-α when y=-1或者0 视情况而定
def focal_loss(alpha=0.5, gamma=1.5, epsilon=1e-6):
    print('*'*20, 'alpha={}, gamma={}'.format(alpha, gamma))
    def focal_loss_calc(y_true, y_probs):
        positive_pt = tf.where(tf.equal(y_true, 1), y_probs, tf.ones_like(y_probs))
        negative_pt = tf.where(tf.equal(y_true, 0), 1-y_probs, tf.ones_like(y_probs))
        
        loss =  -alpha * tf.pow(1-positive_pt, gamma) * tf.math.log(tf.clip_by_value(positive_pt, epsilon, 1.)) - \
            (1-alpha) * tf.pow(1-negative_pt, gamma) * tf.math.log(tf.clip_by_value(negative_pt,  epsilon, 1.))

        return tf.reduce_sum(loss)
    return focal_loss_calc

2.2 α 和 γ \alpha和\gamma αγ调参

调参时,对与 α ∈ [ 0.1 , 0.4 ] 和 γ ∈ [ 1. , 4. ] \alpha\in[0.1, 0.4]和\gamma\in[1., 4.] α[0.1,0.4]γ[1.,4.]进行了调参:

alphas = np.arange(0.1, 0.41, 0.05)#[0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]
gammas = np.arange(1., 4.1, 0.5)#[1.0, 1.5, 2., 2.5, 3., 3.5, 4.]

    对每一组 α 和 γ \alpha和\gamma αγ都训练模型,并对模型进行了评估,评价指标有:TP、FP、TN、FN、accuracy、precision、recall以及AUC曲线。模型就是简单的DNN。评价的结果会写入csv文件中,以便于观察和比较,综合选出最佳 α 和 γ \alpha和\gamma αγ

调参部分代码

initial_bias = np.log([pos/neg])
model = make_model(output_bias = initial_bias, loss_func='focal_loss')
initial_weights = model.get_weights()#bias=np.log([pos/neg])

all_results = []

for i in range(len(alphas)):
    for j in range(len(gammas)):
        
        model.set_weights(initial_weights)#重新初始化模型
        
        model.compile(
            optimizer=tf.keras.optimizers.Adam(lr=1e-3),
            loss=focal_loss(alpha=alphas[i], gamma=gammas[j]),
            metrics=METRICS,
            run_eagerly=True)##############

        focalloss_history = model.fit(
            train_features,
            train_labels,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            callbacks = [early_stopping],
            validation_data=(val_features, val_labels)
        ) 
        
        #评估
        focal_results = model.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0)
        
        focal_metric_res = {'alpha': alphas[i], 'gamma': gammas[j]}
        for name, value in zip(model.metrics_names, focal_results):
            print(name, ': ', value)
            focal_metric_res[name] = value
        print()

        all_results.append(focal_metric_res)

res_df = pd.DataFrame(all_results)
res_df.to_csv('./files/alphas_and_gammas.csv', sep=',', index=False, encoding='UTF-8')

我们来看看评估结果:
在这里插入图片描述
    论文中最佳参数是 α = 0.25 和 γ = 2. \alpha=0.25和\gamma=2. α=0.25γ=2.,对应上图中绿色部分,但在本数据集上并不是最佳的,黄色部分含有较高的AUC值,并且FN和FP相对少。尤其对于这种数据集来说,降低FN也许显得更为重要。例如本例中的信用卡欺诈识别,一个FN可能会引起欺诈交易通过,而一个FP又会使得正常交易会被识别成欺诈行为,而给客户发验证邮件,引起客户的不满意。二者需要权衡。而对于某种疾病比如癌症的识别,FN会错误的把一个癌症患者诊断为健康,这可能会带来巨大的代价。

3. 实现结果说明

作为对比试验,同样是参考了tensorflow2 官网教程“添加经验bias”(作为了baseline)、“类别加权”、“上采样”等方式,进行对比,其分别在训练集和测试集的AUC曲线结果如下:

在这里插入图片描述
在这里插入图片描述
上图可以看出,各路方法都在baseline的基础上有所提升,而且focal loss在precision上的表现还是非常不错的,真实应用时可根据具体情况选择解决方案。

4. 完整代码

完整代码请移步至: 我的github https://github.com/qingyujean/Classification-on-imbalanced-data,求赞求星求鼓励~~~

最后:如果本文中出现任何错误,请您一定要帮忙指正,感激~

参考

[1] tensorflow2 官网教程
[2] focal loss论文

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值