Long-Tail Learning via Logit Adjustment
文章信息
题目:Long-Tail Learning via Logit Adjustment
发表:ICLR,2021
作者:Aditya Krishna Menon, Sadeep Jayasumana, Ankit Singh Rawat,Himanshu Jain, Andreas Veit, Sanjiv Kumar
背景
标签中的不平衡(imbalanced)问题或者长尾分布(Long tail distribution)一直是多类别分类问题中的一个较为普遍的问题,简单来说,训练集中标签的不平衡会导致最终训练出来的模型有偏差Bias,直接表现就是:模型对于多数类样本的预测结果很好,但是对于少数类样本的预测结果普遍较差。所以,不平衡学习(Imbalanced learning)的目标就是研究如何从不平衡的训练集中学习到一个balanced model, 使得特对于各个类别(特别是那些少数类)的样本预测效果都要很好。
正式地,我们实际上希望最小化如下目标:
动机
当前已有很多解决不平衡分类的方法,比如weight normalization, loss modification, cost-sensitive learning, 并且取得了一些比较好的效果。但是这些方法本质上都是启发式的,各有各的局限性。本文从统计学习角度出发,通过利用训练集中标签分布 P ( y ) P(y) P(y) 来修改loss计算中的logit 输出从而提出了一种balanced loss 函数,也就是所谓的Logit adjustment。
方法
理论
从最最小化BER(f)开始,对于Bayes最优的scorer—
f
∗
f^{*}
f∗,从BER的定义可以看出,BER实际上隐式的使用了
P
b
a
l
P^{bal}
Pbal,故:
假设
P
(
y
∣
x
)
∝
e
x
p
s
y
∗
(
x
)
P(y|x) \propto exp{s^{*}_{y}(x)}
P(y∣x)∝expsy∗(x), 根据,
p
b
a
l
(
y
∣
x
)
∝
p
(
y
∣
x
)
/
p
(
y
)
p^{bal}(y|x) \propto p(y|x) / p(y)
pbal(y∣x)∝p(y∣x)/p(y), 则(7)进一步变形为:
上式表明:我们可以通过利用先验类概率P(y)来修正logit从而最小化balanced loss。
具体怎么做呢? 有两条思路:
(1) 直接训练unbiased model—
P
b
a
l
(
y
)
P^{bal}(y)
Pbal(y): 在训练过程中把P(y)添加了loss function中,训练完成后可直接用来预测。
(2)用naive loss (i.e softmax cross-entropy)训练得到bias model—
P
(
y
)
P(y)
P(y),然后在测试/推理时利用P(y)来调整预测的logit。这属于post-hoc。
下面先来看第一种。
方法一:Logit adjust loss
直接建模
P
b
a
l
(
y
∣
x
)
P^{bal}(y|x)
Pbal(y∣x),
P
b
a
l
(
y
∣
x
)
∝
e
x
p
(
f
y
(
x
)
)
P^{bal}(y|x) \propto exp(f_{y}(x))
Pbal(y∣x)∝exp(fy(x)), 结合
p
b
a
l
(
y
∣
x
)
∝
p
(
y
∣
x
)
/
p
(
y
)
p^{bal}(y|x) \propto p(y|x) / p(y)
pbal(y∣x)∝p(y∣x)/p(y),再引入一个超参数
τ
\tau
τ, loss可以改写为:
其中参数
τ
\tau
τ用于调节第二项的权重。可以看到当
τ
=
0
\tau=0
τ=0时,该loss就退化为了softmax cross-entropy,
τ
\tau
τ越大,则模型训练过程中会更加关注少数类样本。
值得注意的是,该损失与对比学习Constrastive learning中的损失以及pair wise loss非常相似。
方法二: post-hoc 方法
当训练好模型以后,做预测时对logit做如下的事后矫正即可:
思考
- 统计分析那块原文看着比较晦涩,方法实际上非常简单,与以往工作的最大不同是有这个方法是从statistical learning 视角推倒出来的,而非启发式的。
- 参数 τ \tau τ实际上用来调节对少数类样本的关注度,这样一来,不平衡分类的研究似乎变成了一个顾此失彼的游戏:不关注少数类吧,少数类样本的效果很差,关注吧,多数类样本的预测效果会有下降。期待有更深入的研究。
References
1.Menon A K, Jayasumana S, Rawat A S, et al. Long-tail learning via logit adjustment[J]. arXiv preprint arXiv:2007.07314, 2020.