Cross-Domain Empirical Risk Minimization for Unbiased Long-Tailed Classification
文章信息
题目:Cross-Domain Empirical Risk Minimization for Unbiased Long-Tailed Classification
发表: AAAI,2022
作者: Beier Zhu 1 , Yulei Niu 1 * , Xian-Sheng Hua 2 , Hanwang Zhang 1
背景
不平衡分类的研究是一个经久不衰的话题, 从最早的采样方法, 代价敏感学习Cost-sensitive learning, Focal loss, 到最近新提出来的Logit adjustment, LAbel distribution DisEntangling (LADE), PolyLoss, 等等,层出不穷. 虽然研究的思路/出发点不尽相同,但是本质上大都是强调对Minority class的关注.
与上述文章不同,这篇文章揭示了不平衡学习中的本质问题,从跨域的经验误差分析出发提出了一种范化性更强的方法,
动机
作者首先揭示了不平衡分类研究中的一个有意思的现象: 相关研究似乎陷入了head vs. tail game. 具体来说: (1) 对于Naive methods (即没有考虑不平衡的分类方法), long-tailed dataset使得模型biased towards head class, 从而对于少数类效果差; (2) 常见的不平衡方法本质上是通过更加关注tail class, 这意味着head class一定程度上被忽视, 从而模型对于测试集和训练集有同样bias的情形时效果很差.
那么, 如何构建真正unbiased的模型呢? 作者从跨域的经验误差最小化分析的角度出发,提出了一个更加general, 更加简洁的Loss, 实验证明: 该方法通过学习更好的feature representation来训练一个unbiased model, 从而使其在balanced & imbalanced test set上效果都很好。
方法
因果分析
作者首先从因果图分析,
X: 输入图像,
Y: 预测值/label
S: 选择变量
对于左边的图,
X
←
S
→
Y
X\leftarrow S\rightarrow Y
X←S→Y, 选择变量S实际上引入了X与Y之间虚假的相关性, 因此直接学习
P
(
Y
∣
X
)
P(Y|X)
P(Y∣X)不可避免的会引入这种虚假的相关性. 作者删掉了指向X的箭头, 并引入了后干预操作
d
o
(
X
)
do(X)
do(X), 这样一来学习的目标变成了
P
(
Y
∣
d
o
(
X
)
)
P(Y|do(X))
P(Y∣do(X)).
xERM
基于上述的因果分析, 要学习的估计器
f
f
f在干预分布(即引入了
d
o
(
X
)
do(X)
do(X)操作之后的)上的经验风险可以定义为:
如何来计算
P
(
y
∣
d
o
(
x
)
)
P(y|do(x))
P(y∣do(x))? 基于全概率公式,作如下展开:
其中:
s=0: 平衡域
s=1:不平衡域
将公式7带入6:
(假设共有N个samples, 且独立同分布)
这里有两个问题需要考虑:
(1) 如何获得
y
s
=
0
y_{s=0}
ys=0
由于训练集本身是biased, balanced domain实际上是不可见的.
y
s
=
0
y_{s=0}
ys=0实际上是通过训练的balanced model----
p
b
a
l
(
y
∣
x
)
p^{bal}(y|x)
pbal(y∣x)来估计的.
(2) 如何估计样本权重
P
(
x
P
(
x
∣
S
)
\frac{P(x}{P(x|S)}
P(x∣S)P(x,
由于
p
(
x
)
p
(
x
∣
s
)
=
s
s
∣
x
\frac{p(x)}{p(x|s)}=\frac{s}{s|x}
p(x∣s)p(x)=s∣xs,进一步假设
P
(
S
=
1
)
=
P
(
S
=
0
)
P(S=1)=P(S=0)
P(S=1)=P(S=0), 可得:
p
(
x
)
p
(
x
∣
S
)
∝
1
p
(
s
∣
x
)
\frac{p(x)}{p(x|S)}\propto\frac{1}{p(s|x)}
p(x∣S)p(x)∝p(s∣x)1. 因此作者提出采用模型的交叉熵损失来衡量两个域各自的weights:
最终,XERM完整的流程如下:
总结
1.感觉这篇文章最重要的是指出了不平衡分类中存在的head vs. tail game问题,
2. 以往的文章主要都是关于误差分布与标记分布之间的关系, 这篇文章在理论分析并结合实验发现, xERM通过学习更好的特征表示来提升性能. 最终证实xERM实现无偏的原因: 通过调整不平衡域和平衡但不可见域上的经验风险来消除由域选择引起的偏差.
References
- Zhu, Beier, et al. “Cross-domain empirical risk minimization for unbiased long-tailed classification.” Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 36. No. 3. 2022.