@INPROCEEDINGS{wang2023freematch,
title = {FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning},
author = {Wang, Yidong and Chen, Hao and Heng, Qiang and Hou, Wenxin and Fan, Yue and and Wu, Zhen and Wang, Jindong and Savvides, Marios and Shinozaki, Takahiro and Raj, Bhiksha and Schiele, Bernt and Xie, Xing},
booktitle = {ICLR},
year = {2023},
pages = {1--20}
}
1. 摘要
-
Semi-supervised Learning (SSL) has witnessed great success owing to the impressive performances brought by various methods based on pseudo labeling and consistency regularization.
半监督学习(semi-supervised learning
)的两大杀器,pseudo labeling
(伪标记)和consistency regularization
(一致性正则)。
-
However, we argue that existing methods might fail to utilize the unlabeled data more effectively since they either use a pre-defined / fixed threshold or an ad-hoc threshold adjusting scheme, resulting in inferior performance and slow convergence.
提出现有方法的不足,use a pre-defined / fixed threshold or an ad-hoc threshold adjusting scheme
。这里的阈值应该是伪标签加入的阈值,只有大于阈值的伪标签才会加入训练,这一点在基于伪标签技术的方法中十分的常用。至于基于一致性正则中是否也存在这样的阈值,这一点是存疑的。
-
We first analyze a motivating example to obtain intuitions on the relationship between the desirable threshold and model’s learning status. Based on the analysis, we hence propose FreeMatch to adjust the confidence threshold in a self-adaptive manner according to the model’s learning status.
顺利提出自己的核心创新点: self-adaptive confidence threshold。
-
We further introduce a self-adaptive class fairness regularization penalty to encourage the model for diverse predictions during the early training stage.
一个trick,避免模型初期过早收敛。
-
Extensive experiments indicate the superiority of FreeMatch especially when the labeled data are extremely rare. FreeMatch achieves 5.78%, 13.59%, and 1.28% error rate reduction over the latest state-of-the-art method FlexMatch on CIFAR-10 with 1 label per class, STL-10 with 4 labels per class, and ImageNet with 100 labels per class, respectively. Moreover, FreeMatch can also boost the performance of imbalanced SSL.
自信的算法用三句话来描述自己是art-of-state的。
-
The codes can be found at https: //github.com/microsoft/Semi-supervised-learning.
代码地址。
2. 算法描述
2.1. 例子
通过一个分类的例子,有以下有趣的结论:
- 简单地说,未标记数据利用率(采样率) 1 − P ( Y p = 0 ) 1−P(Y_p = 0) 1−P(Yp=0) 直接由阈值 τ \tau τ 控制。随着置信度阈值 τ \tau τ 变大,未标记数据利用率变低。在训练初期,由于 β \beta β 仍然很小,采用较高的阈值可能会导致采样率较低且收敛速度较慢。
- 更有趣的是,如果
σ
1
≠
σ
2
\sigma_1 \neq \sigma_2
σ1=σ2,则
P
(
Y
p
=
1
)
≠
P
(
Y
p
=
−
1
)
P(Y_p = 1) \neq P(Y_p = −1)
P(Yp=1)=P(Yp=−1)。事实上,
τ
\tau
τ 越大,伪标签越不平衡。从我们旨在解决平衡分类问题的意义上来说,这可能是不可取的。不平衡的伪标签可能会扭曲决策边界并导致所谓的伪标签偏差。对此的一个简单的补救措施是使用特定于类的阈值
τ
2
\tau_2
τ2 和
1
−
τ
1
1 − \tau_1
1−τ1 来分配伪标签。(
different classes have different levels of intra-class diversity (different σ)
) - 采样率
1
−
P
(
Y
p
=
0
)
1 − P(Y_p = 0)
1−P(Yp=0) 随着
μ
2
−
μ
1
\mu_2 − \mu_1
μ2−μ1 变小而降低。换句话说,两个类越相似,未标记的样本就越有可能被屏蔽。随着两个类别变得更加相似,特征空间中会混合更多的样本,而模型对其预测的信心较差,因此需要一个适度的阈值来平衡采样率。否则,我们可能没有足够的样本来训练模型来对已经很难分类的类进行分类。(
some classes are harder to classify than others (µ2 − µ1 being small
)
Since different classes have different levels of intra-class diversity (different σ) and some classes are harder to classify than others (µ2 − µ1 being small), a fine-grained
class-specific threshold
is desirable to encourage fair assignment of pseudo labels to different classes.
2.2. Self-adaptive Threshold
Global Threshold:
τ
t
=
{
1
C
,
if
t
=
0
,
λ
τ
t
−
1
+
(
1
−
λ
)
1
μ
B
∑
b
=
1
μ
B
max
q
b
,
otherwise
.
\tau_t= \begin{cases} \frac{1}{C},& \text{if } t=0,\\ \lambda\tau_{t-1} + (1-\lambda)\frac{1}{\mu B}\sum_{b=1}^{\mu B}\max{q_b},& \text{otherwise}. \end{cases}
τt={C1,λτt−1+(1−λ)μB1∑b=1μBmaxqb,if t=0,otherwise.
Local Threshold:
p
~
t
(
c
)
=
{
1
C
,
if
t
=
0
,
λ
p
~
t
−
1
(
c
)
+
(
1
−
λ
)
1
μ
B
∑
b
=
1
μ
B
q
b
(
c
)
,
otherwise
.
\widetilde{p}_t(c)= \begin{cases} \frac{1}{C},& \text{if } t=0,\\ \lambda\widetilde{p}_{t-1}(c) + (1-\lambda)\frac{1}{\mu B}\sum_{b=1}^{\mu B}q_b(c),& \text{otherwise}. \end{cases}
p
t(c)={C1,λp
t−1(c)+(1−λ)μB1∑b=1μBqb(c),if t=0,otherwise.
Final Threshold:
τ
t
(
c
)
=
MaxNorm
(
p
~
t
(
c
)
)
⋅
τ
t
\tau_t(c) = \text{MaxNorm}(\widetilde{p}_t(c)) \cdot \tau_t
τt(c)=MaxNorm(p
t(c))⋅τt
一致性正则:
L
u
=
1
μ
B
∑
b
=
1
μ
B
I
(
max
(
q
b
)
≥
τ
t
(
arg max
(
q
b
)
)
)
⋅
H
(
q
^
b
,
Q
b
)
\mathcal{L}_u = \frac{1}{\mu B}\sum_{b=1}^{\mu B}\mathbb{I}(\max(q_b) \geq \tau_t(\argmax(q_b)))\cdot \mathcal{H}(\hat{q}_b, Q_b)
Lu=μB1b=1∑μBI(max(qb)≥τt(argmax(qb)))⋅H(q^b,Qb)
Notice:
原文中,指示函数少打了个括号
2.3. Self-adaptive Fairness
KL散度:
D
K
L
(
p
∥
q
)
=
∑
i
=
1
n
p
(
x
i
)
log
(
p
(
x
i
)
q
(
x
i
)
)
D_{KL}(p\|q) = \sum_{i=1}^np(x_i)\log(\frac{p(x_i)}{q(x_i)})
DKL(p∥q)=i=1∑np(xi)log(q(xi)p(xi))
其中:
- p表示样本的真实分布,q表示模型的预测分布。从KL散度公式中可以看到q分布越接近p(q分布越拟合p),那么散度值越小,即损失值越小。
- KL散度称为KL距离,但它并不满足距离的性质:1. KL散度不是对称的;2. KL散度不满足三角不等式。
交叉熵:
D
K
L
(
p
∥
q
)
=
∑
i
=
1
n
p
(
x
i
)
log
(
p
(
x
i
)
q
(
x
i
)
)
=
∑
i
=
1
n
p
(
x
i
)
log
(
p
(
x
i
)
)
−
∑
i
=
1
n
p
(
x
i
)
log
(
q
(
x
i
)
)
=
−
H
(
p
(
x
)
)
+
[
−
∑
i
=
1
n
p
(
x
i
)
log
(
q
(
x
i
)
)
]
\begin{align*} D_{KL}(p\|q) &= \sum_{i=1}^np(x_i)\log(\frac{p(x_i)}{q(x_i)})\\ &= \sum_{i=1}^n p(x_i)\log(p(x_i)) - \sum_{i=1}^n p(x_i)\log(q(x_i)) \\ &= -\mathcal{H}(p(x)) + [-\sum_{i=1}^n p(x_i)\log(q(x_i))] \end{align*}
DKL(p∥q)=i=1∑np(xi)log(q(xi)p(xi))=i=1∑np(xi)log(p(xi))−i=1∑np(xi)log(q(xi))=−H(p(x))+[−i=1∑np(xi)log(q(xi))]
其中:
- 等式的前一部分恰巧就是p的熵(表示信息量),等式的后一部分,就是交叉熵。
- 在机器学习中,我们需要评估label(GroundTruth)和predicts之间的差距,使用KL散度刚刚好,即:由于KL散度中的前一部分
−
H
(
y
)
−\mathcal{H}(y)
−H(y)不变,故在优化过程中,只需要关注交叉熵就可以了。所以一般在机器学习中直接用用交叉熵做loss,评估模型。而在拟合可变分布时,则采用
KL散度
。
本文提出的SAF正则是基于《Joint Optimization Framework for Learning with Noisy Labels》中的工作的。《Pseudo-Labeling and Confirmation Bias in Deep Semi-Supervised Learning》中的工作则是基于此的,并没有做更改。
The regularization loss L p ( θ ∣ X ) L_p(\theta|X) Lp(θ∣X) is required to prevent the assignment of all labels to a single class: In the case of minimizing only Eq. (6), we obtain a trivial global optimal solution with a network that always predicts constant one-hot y ^ ∈ H \hat{y} \in H y^∈H and each label y i = y ^ y_i = \hat{y} yi=y^ for any image x i x_i xi. To overcome this problem, we introduce a prior probability distribution p \mathbf{p} p, which is a distribution of classes among all training data. If the prior distribution of classes is known, then the updated labels should follow the same. Therefore, we introduce the KL-divergence from s ‾ ( θ , X ) \overline{s}(\theta,X) s(θ,X) to p \mathbf{p} p as a cost function as follows:
L p = ∑ j = 1 c p j log ( p j s ‾ ( θ , X ) ) \mathcal{L}_p = \sum_{j=1}^{c}p_j \log(\frac{p_j}{\overline{s}(\theta,X)}) Lp=j=1∑cpjlog(s(θ,X)pj)
s ‾ ( θ , X ) = ∑ i = 1 n s ( θ , x i ) \overline{s}(\theta,X) = \sum_{i=1}^{n}s(\theta,x_i) s(θ,X)=i=1∑ns(θ,xi)
Notice:
这里的正则是KL散度, p表示先验概率,这里采用的是均匀分布;
s
‾
(
θ
,
X
)
\overline{s}(\theta,X)
s(θ,X) 表示模型的平均预测。
SAF:
这个正则的目的是鼓励模型对每个类别做出不同的预测,从而产生有意义的自适应阈值,特别是在标记数据很少的情况下。不同于只是要求模型对于无标记样本预测类别平衡(各个类别预测数量一样),数量波动也是自适应的。SAF 鼓励每个小批量的输出概率在通过直方图分布归一化后接近模型的边缘类分布。
p
‾
=
1
μ
B
∑
b
=
1
μ
B
I
(
max
(
q
b
)
≥
τ
t
(
arg max
(
q
b
)
)
)
Q
b
\overline{p} = \frac{1}{\mu B}\sum_{b=1}^{\mu B}\mathbb{I}(\max(q_b) \geq \tau_t(\argmax(q_b)))Q_b
p=μB1b=1∑μBI(max(qb)≥τt(argmax(qb)))Qb
h
‾
=
1
μ
B
Hist
μ
B
(
I
(
max
(
q
b
)
≥
τ
t
(
arg max
(
q
b
)
)
)
Q
^
b
)
\overline{h} = \frac{1}{\mu B}\text{Hist}_{\mu B}(\mathbb{I}(\max(q_b) \geq \tau_t(\argmax(q_b)))\hat{Q}_b)
h=μB1HistμB(I(max(qb)≥τt(argmax(qb)))Q^b)
h
~
t
=
λ
h
~
t
−
1
+
(
1
−
λ
)
Hist
μ
B
(
q
^
b
)
)
\widetilde{h}_t = \lambda \widetilde{h}_{t-1} + (1-\lambda)\text{Hist}_{\mu B}(\hat{q}_b))
h
t=λh
t−1+(1−λ)HistμB(q^b))
L f = − H ( SumNorm ( q ~ t h ~ t ) , SumNorm ( q ‾ t h ‾ t ) ) \mathcal{L}_f = -\mathcal{H}(\text{SumNorm}(\frac{\widetilde{q}_t }{\widetilde{h}_t }), \text{SumNorm}(\frac{\overline{q}_t }{\overline{h}_t })) Lf=−H(SumNorm(h tq t),SumNorm(htqt))
疑问?
- 这里的 L f \mathcal{L}_f Lf 表示交叉熵的相反数,那不是相当于是负数了。
- 这里魔改太多了,想了很久,还是没法明白这是为什么?
Answer:
我在这里尝试回答一下自己的疑问。
首先,可以肯定的是
L
f
\mathcal{L}_f
Lf是负的。
这里借鉴的
L
f
\mathcal{L}_f
Lf可以理解为
SumNorm
(
q
‾
t
h
‾
t
)
\text{SumNorm}(\frac{\overline{q}_t }{\overline{h}_t })
SumNorm(htqt)的熵,可以理解为信息量。但是信息量不能为负数,这里相当于最大化熵,即要求分布尽可能均匀,每个类的概率值相同。
那么现在就还有一个疑问了,为什么不是类似于 − H ( SumNorm ( q ‾ t h ‾ t ) , SumNorm ( q ‾ t h ‾ t ) ) -\mathcal{H}(\text{SumNorm}(\frac{\overline{q}_t }{\overline{h}_t }), \text{SumNorm}(\frac{\overline{q}_t }{\overline{h}_t })) −H(SumNorm(htqt),SumNorm(htqt))这样的形势。猜测是为了简化运算, SumNorm ( q ~ t h ~ t ) \text{SumNorm}(\frac{\widetilde{q}_t }{\widetilde{h}_t }) SumNorm(h tq t)其实是没有梯度值的,相当于是一个标量。当然这是强行解释了,其实不必纠结为什么吧,效果好就是了,原论文也没有进一步的解释。
3. 总结
论文写作技巧拉满了。关于自适应阈值这一块我倒是理解了,但是关于SAF
这个负号我实在是没法理解。