视觉分类任务中处理不平衡问题的loss比较

77f59fe8b3929462ac6232f4e41547cf.gif

向AI转型的程序员都关注了这个号👇👇👇

机器学习AI算法工程   公众号:datayx

在计算机视觉(CV)任务里常常会碰到类别不平衡的问题, 例如:

1. 图片分类任务,有的类别图片多,有的类别图片少

2. 检测任务。现在的检测方法如SSD和RCNN系列,都使用anchor机制。训练时正负anchor的比例很悬殊.

3. 分割任务, 背景像素数量通常远大于前景像素。

从实质上来讲, 它们可以归类成分类问题中的类别不平衡问题:对图片/anchor/像素的分类。

再者,除了类不平衡问题, 还有easy sample overwhelming的问题。easy sample如果太多,可能会将有效梯度稀释掉。

这两个问题通常都会一起出现。如果不处理, 可能会对模型性能造成很大伤害。用Focal Loss里的话说,就是训练不给力, 且会造成模型退化:

(1) training is inefficient as most locations are easy negatives…

(2) the easy negatives can overwhelming training and lead to degenerate models.

如果要处理,那么该怎么处理呢?在CV领域里, 若不考虑修改模型本身, 通常会在loss上做文章, 确切地说,是在样本选择或loss weight上做文章。

常见的解决办法介绍

常见的方法有online的, 也有非online的;有只处理类间不平衡的,有只处理easy example的, 也有同时处理两者的。

Hard Negative Mining, 非online的mining/boosting方法, 以‘古老’的RCNN(2014)为代表, 但在CV里现在应该没有人使用了(吧?)。若感兴趣,推荐去看看OHEM论文里的related work部分。

Mini-batch Sampling,以Fast R-CNN(2015)和Faster R-CNN(2016)为代表。Fast RCNN在训练分类器, Faster R-CNN在训练RPN时,都会从N = 1或2张图片上随机选取mini_batch_size/2个RoI或anchor, 使用正负样本的比例为1:1。若正样本数量不足就用负样本填充。使用这种方法的人应该也很少了。从这个方法开始, 包括后面列出的都是online的方法。

Online Hard Example Mining, OHEM(2016)。将所有sample根据当前loss排序,选出loss最大的N个,其余的抛弃。这个方法就只处理了easy sample的问题。

Oline Hard Negative Mining, OHNM, SSD(2016)里使用的一个OHEM变种, 在Focal Loss里代号为OHEM 1:3。在计算loss时, 使用所有的positive anchor, 使用OHEM选择3倍于positive anchor的negative anchor。同时考虑了类间平衡与easy sample。

Class Balanced Loss。计算loss时,正负样本上的loss分别计算, 然后通过权重来平衡两者。暂时没找到是在哪提出来的,反正就这么被用起来了。它只考虑了类间平衡。

Focal Loss(2017), 最近提出来的。不会像OHEM那样抛弃一部分样本, 而是和Class Balance一样考虑了每个样本, 不同的是难易样本上的loss权重是根据样本难度计算出来的。

从更广义的角度来看,这些方法都是在计算loss时通过给样本加权重来解决不平衡与easy example的问题。不同的是,OHEM使用了hard weight(只有0或1),而Focal Loss使用了soft weight(0到1之间).

现在依然常用的方法特性比较如下:

6a662e0067cc349e8726bef1c5652626.png

接下来, 通过修改过的Cifar数据集来比较这几种方法在分类任务上的表现,当然, 主要还是期待Focal Loss的表现。

实验数据

实验数据集

Cifar-10, Cifar-100。使用Cifar的原因没有别的, 就因为穷,毕竟要像Focal Loss论文里那样跑那么多的大实验对大部分学校和企业来说是不现实的。

处理数据得到类间不平衡

将多分类任务转换成二分类:

new_label = label == 1

原始Cifar-10和100里有很多类别,每类图片的数量基本一样。按照这种方式转变后,多分类变成了二分类, 且正负样本比例相差悬殊:9倍和99倍。

实验模型

一个5层的CNN,完成一个不平衡的二分类任务。使用Cross Entropy Loss,按照不同的方法使用不同的权值方案。以不加任何权重的CE Loss作为baseline。

衡量方式

在这种不平衡的二分类问题里, 准确率已经不适合用来衡量模型的好与坏了。此处使用F-Score作标准.

实现细节

CE(Cross Entroy Loss)

c482c11947f42a77e26d202d2a1f0933.png

OHEM

分为以下三步:
1. 计算ce_loss, 同CE
2. 根据ce_loss排序, 选出top N 个sample:

5d53c8861b315a039266f973d7b39c40.png

1a757b577622ace54a9a3b2710991a05.png

Class Balance CE

形式多种多样,我个人最喜欢使用:

9ad6fc1adc2577483afb873f38a96c17.png

cf77c5e1a691f693ffda6fac42453f9a.png

5cd2e17c24bf5ccf16a6cb95af122339.png

优化方法

  • 最简单的SGD, 初始lr=0.1, 每200,000步衰减一次, 衰减系数为0.1。Cifar-100上focal_loss的初始lr=0.01。

  • batch_size = 128.

实验结果

CIFAR-10:

69bb84f7c617b8dc3945fe7856bc8b70.png

b359d1b5dc8abdc75e64d6898a234db6.png

Focal Loss的一个补丁

对于CIFAR-100,batch_size=128时, 一个batch内可能会一个positive sample都没有, 即n_pos == 0, 这时,paper里用n_pos来normalize loss 的方式就不可行了。测试过两种简单的选择:一是用所有weight之和来normalize, 二是直接不normalize。前者很难训练甚至训练不出来, 后者可用。所以上面的Focal loss计算代码应该补充为:

21c57137781eba96764b631dc6a9e45f.png

经验总结

1231bd9a4e3a23e054193cd1edaa90f1.png

Code Available On Github

https://github.com/dengdan/test_tf_models

Branch:focal_loss

References

Focal Loss for Dense Object Detection, https://arxiv.org/pdf/1708.02002.pdf

RCNN, https://arxiv.org/abs/1311.2524

Fast RCNN, http://arxiv.org/abs/1504.08083

Faster-RCNN, http://arxiv.org/abs/1506.01497

Training Region-based Object Detectors with Online Hard Example Mining, https://arxiv.org/abs/1604.03540

机器学习算法AI大数据技术

 搜索公众号添加: datanlp

36090d26dcf2062914858d3279e7ff02.png

长按图片,识别二维码


阅读过本文的人还看了以下文章:

TensorFlow 2.0深度学习案例实战

基于40万表格数据集TableBank,用MaskRCNN做表格检测

《基于深度学习的自然语言处理》中/英PDF

Deep Learning 中文版初版-周志华团队

【全套视频课】最全的目标检测算法系列讲解,通俗易懂!

《美团机器学习实践》_美团算法团队.pdf

《深度学习入门:基于Python的理论与实现》高清中文PDF+源码

《深度学习:基于Keras的Python实践》PDF和代码

特征提取与图像处理(第二版).pdf

python就业班学习视频,从入门到实战项目

2019最新《PyTorch自然语言处理》英、中文版PDF+源码

《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码

《深度学习之pytorch》pdf+附书源码

PyTorch深度学习快速实战入门《pytorch-handbook》

【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》

《Python数据分析与挖掘实战》PDF+完整源码

汽车行业完整知识图谱项目实战视频(全23课)

李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材

笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!

《神经网络与深度学习》最新2018版中英PDF+源码

将机器学习模型部署为REST API

FashionAI服装属性标签图像识别Top1-5方案分享

重要开源!CNN-RNN-CTC 实现手写汉字识别

yolo3 检测出图像中的不规则汉字

同样是机器学习算法工程师,你的面试为什么过不了?

前海征信大数据算法:风险概率预测

【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类

VGG16迁移学习,实现医学图像识别分类工程项目

特征工程(一)

特征工程(二) :文本数据的展开、过滤和分块

特征工程(三):特征缩放,从词袋到 TF-IDF

特征工程(四): 类别特征

特征工程(五): PCA 降维

特征工程(六): 非线性特征提取和模型堆叠

特征工程(七):图像特征提取和深度学习

如何利用全新的决策树集成级联结构gcForest做特征工程并打分?

Machine Learning Yearning 中文翻译稿

蚂蚁金服2018秋招-算法工程师(共四面)通过

全球AI挑战-场景分类的比赛源码(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)

python+flask搭建CNN在线识别手写中文网站

中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程

不断更新资源

深度学习、机器学习、数据分析、python

 搜索公众号添加: datayx  

bfdab4f9f4e04b3d33b4a58994e198cd.png

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值