分类模型:类别不均衡问题之loss设计

0f97030383f766d63396df7e07c130d6.gif

©作者 | Qun

702de6b05042b07a4d321c4e5fe11af5.png

前言

数据类别不均衡是很多场景任务下会遇到的一种问题。比如 NLP 中的命名实体识别 NER,文本中许多都是某一种或者几种类型的实体,比如无需识别的不重要实体;又或者常见的分类任务,大部分数据的标签都是某几类。

而我们又无法直接排除这些很少的类别的数据,因为这些类别也很重要,仍然需要模型去预测这些类别。

4bdb40d24e9074991fbc6d6afedf0d0e.png

数据采样

有时会从数据层面缓解这种类别不均衡带来的影响,主要是过采样和欠采样。

  • 过采样:对于某些类别数据比较少,对它们进行重复采样,以达到相对平衡,重复采样的时候,有时也会对数据加上一点噪声;

  • 欠采样:对于某些类别数据特别多,只使用部分数据,抛弃一些数据;

过采样可能导致这些类别产生过拟合的现象,而欠采样则容易导致模型的泛化性变差。

另外,比较常用的则是结合 ensemble 方法,则将数据切分为 N 部分,每部分都包含数据少的类别的所有样本和数据多的类别的部分样本,训练 N 个模型,最后进行集成。

缺点是,使用 ensemble 则会提高部署成本和带来性能问题。

0156e1679fee782ab5f3b6bcea147cb2.png

损失函数

如果在损失函数方面下功夫,针对这种类别不均衡的场景设计一种 loss,能够兼顾数据少的类别,这其实是一种更理想的做法,因为不会破坏原数据的分布,并且不会带来性能问题(对训练速度可能有轻微的影响,但不影响推理速度)。

下面,我们就介绍在这方面几种常用的 loss 设计。

fcf882046b55049ea9fbc1295e42736b.png


Focal Loss

Focal Loss 是一种专门为类别不均衡设计的loss,在《Focal Loss for Dense Object Detection》这篇论文中被提出来,应用到了目标检测任务的训练中,但其实这是一种无关特定领域的思想,可以应用到任何类别不均衡的数据中。

首先,还是从二分类的cross entropy(CE)loss入手:

a9c2ea9a503a8a1c58959ddd93cdb484.png

为了符号的方便,。

d3884d422e308d3197d000612638955b.png

为模型对于 label=1(ground-truth)的类别的预测概率。

4.1 问题分析

下图的蓝色曲线为原生的 CE loss,容易看出来,那些容易分类()的样本也会产生不小的 loss,但这些大量的容易样本的 loss 加起来,会压过那些稀少类别的样本的 loss。

36e09c0e65dd6318314f332ae711ecd3.png

▲ 图1

4.2 CE改进

针对类别不均衡,普遍的做法是 -balanced CE。给 CE loss 增加一个权重因子 ,正样本权重因子为 ,负样本为 。

35f0fa7f83f0f05c5e3c4f5b7af5e04c.png

实际使用中,一般设置为类别的逆反频率,即频率低的类别权重应该更大,比如稀少的正样本的 为负样本的频率。或者当作一个超参数。

但是,这种做法只是平衡了正负样本的重要性,无法区分容易(easy)样本和困难(hard)样本,这也是类别不均衡的数据集很容易出现的问题:容易分类的样本贡献了大部分的 loss,并且主导了梯度。

因此,Focal Loss 的主要思想就是让 loss 关注那些困难样本,而降低容易样本的重要性。

1fe944286b4c1a7bf6e6707e8686ba89.png

如上式,在 CE 的基础增加一个调节因子 。上图 1 可以看出, 越大,容易样本的 loss 贡献越小。

Focal Loss 具有以下两个属性:

1. 当一个样本被错误分类时,且 很小时(即为困难样本),那么调节因子是接近 1 的,loss 则基本不受影响。而相反的,当 ,分类很好的样本(容易样本),调节因子则会偏于 0,loss 贡献变得很小;

2. 不同的 参数可以平滑地调整容易样本的重要性降低的比率。当 时,则等同于普通的 CE。而当 变大时,那么调节因子的影响也会同样变大,即容易样本的重要性会降低。

论文在实验中,Focal Loss 还保留上述的权重因子 :

e8bc2ca11e4ed02a92986e70aa07e766.png

通常来说,当增加 时, 应该稍微降低。

在作者的实验中, 取得了最佳性能。

这里的 是稀少类别的权重因子,但按照上述 -balanced CE 的分析,稀少类别的权重因子应该更大才对。但 和 是互相作用的,论文经过多次实验,调整 带来的收益更大,而大的 ,应该搭配小的权重因子 。

e51b644069d45d28bc19c0d1570aa814.png

GHM(Gradient Harmonizing Mechanism)

梯度调和机制 Gradient Harmonizing Mechanism 的设计目的也是为了解决不均衡类别的问题而对 loss 函数进行优化,出自《Gradient Harmonized Single-stage Detector》。

5.1 介绍

GHM 同样表述了关于容易样本和困难样本的观点 A:模型从容易分类的样本的到收益很少,模型应该关注那些困难分类的样本,不管它属于哪一种类别,但大量的容易样本加起来的贡献会盖过困难样本,使得训练效益很低;

进一步指出 Focal Loss 的问题,提出不同的观点 B:

1. Focal Loss 存在两个超参数,并且是互相影响,构成许多参数组合,会导致调参需要很多尝试成本;并且,Focal Loss 是一种静态的 loss,那么同一种超参数无法适用于不同的数据分布;

2. 有一些特别困难分类的样本,它们很可能是离群点,加入这些样本的训练,会影响模型的稳定性;

3. 提出了 gradient density(梯度密度)的梯度调和机制,来缓解这种类别不均衡的问题。

下图左展示了上述观点 A,梯度范数 gradient norm 的大小则代表样本的分类难易程度,收益实际即对应为梯度;

下图中和右展示了通过 GHM 的梯度调和之后,容易样本的 gradient norm 会被削弱许多,并且特别困难的样本也会被轻微削弱,分别对应观点 A 和观点 B-2 的解决方案。

df52068fc8621904ef35e55732c7a880.png

5.2 理论

基于这些分析,论文提出了一种梯度调和机制 GHM(Gradient Harmonizing Mechanism),其主要思想是:首先仍然是降权大量容易样本贡献的梯度总和,其次是对于那些特别困难样本即离群点,也应当相对地降权

对于二分类问题,同样的交叉熵 loss 如下:

512f84629437d6d642c11c70e166a115.png

790b07e50bc2bb54e1dc3ee9819dd772.png

569d9b594400cfcf4917d4303caa99bf.png

其中, 为模型的预测概率, 为真实的标签 ground-truth label;

x 为模型 unnormalized 的直接输出, 。

这里的 g 为 gradient norm,可以用来表示一个样本的分类难易程度以及对在全局梯度中的影响程度,g 越大则分类难度越高。

下图 2 展示了在目标检测模型中 gradient norm 的分布情况,表明了容易样本在梯度中会占主导地位,以及模型无法处理一些特别困难的样本,这些样本的数量甚至超过了中等困难的样本,但模型不应过于关注这些样本,因为它们可以认为是离群点。(对应上述观点 A 和观点 B-2)

7ce2eeb1b9bb33aeead60fef894f05ff.png

▲ 图2

为了解决这种 gradient norm 的分布问题,论文提出了一种调和手段:Gradient Density

95fdc856da63441f55a01071ad325a4f.png

a02e847507b68c28b7d2d6b44aa91552.png

其中, 为第 k 个样本的 gradient norm。

g 的梯度密度即 表示落于以 g 为中心,长度为 的中心区域的样本数量,然后除以有效长度进行标准化。

那么,梯度密度调和参数为:

3f2d8906739e61a6f46e04a29aa35fed.png

N 为样本数量。

可以看作是梯度上在第 i 个样本周边样本频率的一种正则化:

1. 如果所有样本的梯度是均匀分布的,那么对于每个样本 i:,意味着每个样本都没起到任何改变

2. 见上图 2,容易样本的频率很高,那么 就会变得很小,起到降低这些样本的权重的效果;并且特别困难样本即离群点的频率会比中等困难的样本频率多,意味着这些离群点的 会相对较小,那么也会相对地轻微降低这些样本的权重;

3. 从第 2 点可以看出,GHM 其实只适用于那些容易样本和特别困难样本的数量比中等困难样本多的场景。

因此,经过 GHM 调和之后的 loss 为:

5a580771069b4148333b16863290a46d.png

5.3 计算优化

很容易算出,GHM 的计算复杂度是 ,论文通过 Unit Region 的方法来逼近原生的梯度密度,大大降低了计算复杂度。

首先,将 gradient norm 的值域空间 [0,1] 划分为 M 个长度为 的 Unit Region。对于第 j 个 Unit Region: 。

接着,让 等于落在 的样本数量;并且,定义 ,即计算 g 所在的 Unit Region 的索引的函数

则,梯度密度的近似函数如下,得到计算复杂度优化的 GHM Loss:

dbc967e80b6dc08efd49f0a0f6797d44.png

5dfff2c9d25bd7d8c9961005c6218edd.png

aacf062922c2717db5f8900d88814242.png

这里怎么理解这种近似思想呢:

1. 先回忆原生 GHM 的梯度密度计算:g 的梯度密度即 表示落于以 g 为中心,长度为 的中心区域的样本数量,然后除以有效长度进行标准化;

2. 将 gradient norm 划分了 M 个 Unit Region 之后,假如第 i 个样本的 落入第 j 个 Unit Region,那么同样落入该 Unit Regio 的样本可以认为是落于以 为中心的中心区域,并且有效长度为 ,即得到上述的近似梯度密度函数。

5.4 结合EMA

最后,在使用 Unit Region 优化之后,还结合 Exponential moving average(EMA)的思想,让梯度密度更加平滑,减少对极端数据的敏感度:

85296bf507340d7dd557d5ee7c10dd9e.png

a2d09aa37f010dc53f0db812fb52cc36.png

为在 t 次遍历中,落入第 j 个 Unit Region 的数量;

即为 EMA 中的 momentum 参数。


5.5 超参数实验

9c60de2fea2ee6cf32f627b0056ec820.png

304e5f1cb95a077d445d2342d763de7c.png

Dice Loss

Dice Loss 来自《Dice Loss for Data-imbalanced NLP Tasks》这篇论文,阐述在 NLP 的场景中,这种类别数据不均衡的问题也是十分常见,比如机器阅读理解machine reading comprehension(MRC),与上述论文表明的观点大致相同:

  • 负样本数量远超过正样本,导致容易的负样本会主导了模型的训练;

  • 另外,还指出交叉熵其实是准确率(accuracy)导向的,导致了训练和测试的不一致。在训练过程中,每一个样本对目标函数的贡献是相同,但是在测试的时候,像分类任务很重要的指标 F1 score,由于正样本数量很少,每一个正样本就对于 F1 score 的贡献则更多了。


6.1 Dice Coeffificient

dice coeffificient 是一种 F1 导向的统计,用于计算两个集合的相似度:

3e2ebaa5e0d6a68971701f5bd44e90d5.png

对应到二分类场景中,A 是模型预测为正样本的样本集合,B 是真实的正样本集合。此时,dice coefficient 其实等同于 F1:

795439788a72bdb3f6ad5366b13a16ff.png

对于每一个样本 ,它对应 dice coefficient 的为:

de751197127e36c9dddb52c82997ae1a.png

但是,显而易见,这样会导致负样本()对目标的贡献为 0。因此,为了避免负样本的作用为 0,让训练更加平滑,在分子和分母中同时加入一个因子 :

de5ba83cba294af6cd56152a63904ce5.png

为了更快地收敛,分母可以为平方的形式,那么 Dice Loss(DL)则变为:

(修改为 1-DSC,目的是让 DSC 最大化变成目标函数最小化,这是 loss 函数常用的转换套路了,并且让 loss 为正数)

d9c452797aaa8f80e474170aa5fc93a3.png

另外,以计算 set-level 的 dice coefficient,而不是独立样本的 dice coefficient 加起来,这样可以让模型更加容易学习:

07c9c346b934e24e08b23e59f730c115.png

6.2 自调节

上述未经过平滑的 DSC 公式其实是 F1 的 soft 版本,因为对于 F1 score,只存在正判或误判。模型预估通常以 0.5 为边界来判断是否为正样本:

7be353c2c700c7d9725604573ab4e1e2.png

DSC 使用连续的概率 p,而不是使用二分 ,这种 gap 对于均衡的数据集不是什么大问题。

但是对于大部分为容易的负样本的数据集来说,是存在极端的害处:

  • 容易分类的负样本很容易主导整个训练过程,因为它们的预测概率相对来说更容易接近 0;

  • 同时,模型会变得难以区分困难分类的负样本和正样本,这对于 F1 score 的表现有着很大的负向影响。

为了解决这种问题,DSC 在原来的基础上,给 soft 概率 p 乘上一个衰减因子 :

e5ddc971fdb131556b04e26cf5d8d39e.png

是一个与每一个样本关联的权重,并且在训练过程会动态改变,根据样本的分类难易程度,实现对样本权重的自调节:

对于预测概率接近 0 和 1 的容易样本,该值明显小很多,可以减少模型对这些样本的关注。

6.3 实验结果

26ae1566e23d53719453e9534dcdb0f3.png

dbe0a5f27a4bc5bc37a9bbcd7e3113b0.png

60f9d181629abb396bcf39a3d02cf443.png

Label Smoothing

最后再讲一下标签平滑,它不是针对不均衡类别设计的 loss 优化,但不失为一种提升分类模型泛化能力的有效措施。

出自这篇论文《Rethinking the Inception Architecture for Computer Vision》,它是交叉熵 loss 的另外一种正则化形式:Label Smoothing。


7.1 cross-entropy

在 K 分类模型中,第 k 个 label 的预估概率为:

621be01049f4dcea5f15b9479b28a8ed.png

, 为 logits

ground-truth 真实 label 为: 。

那么,对应的交叉熵 loss 则为:

4ffad86e11b1169de2f1d5df6577fa56.png

7.2 存在问题

对 求导得到梯度为:,并且范围在 -1 到 1

对于我们的交叉熵 loss,最小化则等同于真实标签的最大似然,而仅当 时才能达到最大似然, 当 k=y 时为 1,其他则为 0。

而对于有限值的 是无法达到这种最大似然的情况,但可以接近这种情况,当所有的 ,即当对应 ground-truth 的 logits 远远大于其他的 logits,直观上来看,这是由于模型对自己的预测结果太过于自信了,这会产生以下两个问题:

  1. 它可能会造成过拟合。如果模型学习到了为每个样本的 ground-truth label 赋予完全的概率,那这无法保证泛化性;

  2. 它鼓励最大的 logit 和其他的 logits 差别尽可能大,再加上梯度的边界仅在 -1 到 1,这会降低模型的适应(adapt)能力。


7.3 正则化

基于上述分析,作者提出了一种优化的交叉熵,增加了正则化:label-smoothing regularization

a6581ebbd809681db8bf31151d977eb1.png

393b8e07d988fb23436b9575b5260639.png

其中, 为 [0,1] 的超参数,K 为标签类别数量。

  • 这种方法避免了最大的 logit 比其他 logits 太过于大,给模型增加了正则化,提升了模型的泛化能力;

  • 即使发生这种情况,交叉熵 loss 会变得更大,因为不同于 ,每个 都会贡献 loss。

在论文中,ImageNet 数据集的实验中, 取值为 0.1。

88a07ff9a1df1210443443817ace15f8.png


总结

本文介绍了几种针对类别不均衡的数据集提出的解决思路,主要的观点都是数量很多的类别存在许多容易分类的样本,这些对于模型训练的贡献很小,但由于数据巨大,会主导模型的训练过程。

FocalLoss和DiceLoss思想比较接近,都是为了减少模型对容易样本的关注而进行的loss优化,而GHMLoss除了对容易样本降权,还实现了对特别困难样本的轻微降权,因为特别困难的样本可以认为是离群点。

GHM Loss 仅适用于二分类,而 Focal Loss 和 Dice Loss 很容易扩展到多分类,但实际使用中 Focal Loss对于多分类调参比较困难(每种类别对应的 -balanced,加上 ,参数组合过于多)。

最后介绍的 Label Smoothing 虽然不是针对类别不均衡的问题,但在分类模型中,其效果往往比原生的交叉熵有些小提升。


代码实现

tensorflow 及 torch 的实现:github:

https://github.com/QunBB/DeepLearning/tree/main/Trick/unbalance

更多阅读

293198baa51275accb3c37211f4683fc.png

1f28d7ce588d09bcdd00afc62c20f4f9.png

a52b1951f7ad12d2a0a3292ee1ad6791.png

e1141c3816d31d4317b2ea0405f61889.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

7c7b17ee7afd6d21294c67a8b2eae6db.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

b2b8d303b0f4453171d5920a0e85ca85.jpeg

  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值