FixMatch 是对现有 SSL 方法的简化. FixMatch 首先对弱增强的未标记图像生成伪标签, 接着, 对同一图像进行强增强后, 再计算其预测分布, 最后计算强增强的预测与伪标签之间的交叉熵损失.
论文地址: FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
代码地址: https://github.com/google-research/fixmatch
会议: NeurIPS 2020
任务: 分类
FixMatch
FixMatch 是 SSL 两种方法的组合: 一致性正则化和伪标签. 它的新颖之处在于这两种方法的组合以及在执行一致性正则化时使用单独的弱增强和强增强.
FixMatch 简要示意图如下:
将弱增强图像输入模型, 当某一预测类别概率高于阈值(虚线)时, 预测将转换为 one-hot 伪标签. 然后, 计算模型对同一图像的强增强的预测. 计算强增强的预测与伪标签之间的交叉熵损失.
文中符号系统如下:
- X = ( ( x b , p b ) ; b ∈ ( 1 , … , B ) ) \mathcal{X}=((x_b,p_b);b\in(1,\dots,B)) X=((xb,pb);b∈(1,…,B)) 为一个 batch_size B B B 的带标签示例.
- U = ( ( u b ) ; b ∈ ( 1 , … , μ B ) \mathcal{U}=((u_b);b\in(1,\dots,\mu B) U=((ub);b∈(1,…,μB) 为一个 batch_size μ B \mu B μB 的无标签示例, 其中 μ \mu μ 是确定 X \mathcal{X} X 和 U \mathcal{U} U 相对大小的超参数.
- p m ( y ∣ x ) p_m(y\vert x) pm(y∣x) 为预测类别分布.
- H ( p , q ) \mathrm{H}(p,q) H(p,q) 为两个概率 p p p, q q q分布之间的交叉熵.
- A ( ) \mathcal{A}() A(), α ( ) \alpha() α() 分别为不同类型的增强.
一致性正则化及伪标签方法简要介绍如下:
Consistency regularization. 关于一致性正则化, 核心就是基于平滑假设, 模型对于对增强后数据的预测应与原始数据预测的结果一致.
Pseudo-labeling. 即利用模型本身来获取未标记数据的人工标签. 更具体地说,
p
b
p_b
pb 的伪标签
q
b
q_b
qb 可以分别定义为基于锐化的连续分布(软)或基于
arg max
\argmax
argmax 操作的独热分布(硬). 在本文里, 人工标签一般指"硬"标签, 并且只保留最大类别概率高于预定阈值的情况. pseudo-labeling 使用如下损失函数:
1
μ
B
∑
b
=
1
μ
b
(
max
(
q
b
)
≥
τ
)
H
(
q
^
b
,
q
b
)
(1)
\frac{1}{\mu B} \sum_{b=1}^{\mu b}(\max(q_b) \geq \tau)\mathrm{H}(\hat{q}_b,q_b) \tag{1}
μB1b=1∑μb(max(qb)≥τ)H(q^b,qb)(1)
其中
q
b
=
p
m
(
y
∣
u
b
)
q_b=p_m(y\vert u_b)
qb=pm(y∣ub),
q
^
b
=
arg max
(
q
b
)
\hat{q}_b=\argmax(q_b)
q^b=argmax(qb),
τ
\tau
τ 为阈值. 鼓励模型的预测是对未标记数据的低熵, 或者说是高置信度.
FixMatch 算法
FixMatch 的损失函数由两个交叉熵损失项组成: 应用于标记数据的监督损失
ℓ
s
\ell_s
ℓs 和无监督损失
ℓ
u
\ell_u
ℓu. 具体来说,
ℓ
s
\ell_s
ℓs 只是弱增强标记示例上的标准交叉熵损失:
ℓ
s
=
1
B
∑
b
=
1
B
B
(
p
b
,
p
m
(
y
∣
α
(
x
b
)
)
)
(2)
\ell_s=\frac{1}{B} \sum_{b=1}^B \mathrm{B}(p_b,p_m(y\vert \alpha(x_b))) \tag{2}
ℓs=B1b=1∑BB(pb,pm(y∣α(xb)))(2)
FixMatch 为每个未标记的示例计算一个人工标签, 然后将其用于标准交叉熵损失. 为了获得人工标签, 首先在给定未标记图像的弱增强版本的情况下计算模型的预测类别分布:
q
b
=
p
m
(
y
∣
α
(
u
b
)
)
q_b =p_m(y \vert \alpha(u_b))
qb=pm(y∣α(ub)). 然后, 使用
q
^
b
=
arg max
(
q
b
)
\hat{q}_b = \argmax(q_b)
q^b=argmax(qb) 作为伪标签, 与
u
b
u_b
ub 的强增强版本做交叉熵损失:
ℓ
u
=
1
μ
B
∑
b
=
1
μ
B
(
max
(
q
b
)
≥
τ
)
H
(
q
^
b
,
p
m
(
y
∣
A
(
u
b
)
)
)
(3)
\ell_u=\frac{1}{\mu B} \sum_{b=1}^{\mu B} (\max(q_b)\geq \tau) \mathrm{H}(\hat{q}_b,p_m(y\vert \mathcal{A}(u_b))) \tag{3}
ℓu=μB1b=1∑μB(max(qb)≥τ)H(q^b,pm(y∣A(ub)))(3)
综上, FixMatch 的损失函数定义为:
l
o
s
s
=
ℓ
s
+
λ
ℓ
u
loss=\ell_s+\lambda\ell_u
loss=ℓs+λℓu. 完整的算法如下:
- 1.计算弱增强标签数据集上的交叉熵损失 ℓ s \ell_s ℓs.
- 2.对每一个 μ B \mu B μB batch, 计算弱增强无标签数据集上的预测分布及伪标签 q b q_b qb, q ^ b \hat{q}_b q^b.
- 3.计算无标签数据交叉熵损失 ℓ u \ell_u ℓu
- 4.得到目标函数总损失 ℓ s + λ ℓ u \ell_s+\lambda\ell_u ℓs+λℓu.
FixMatch 中使用的增强方法
FixMatch 利用了两种增强: “弱"和"强”.
- 弱增强是一种标准的翻转和移位增强策略. 例如在数据集上以 50% 的概率随机水平翻转图像, 并且在垂直和水平方向上随机平移.
- 对于"强"增强, 文中尝试了两种基于 AutoAugment 的方法, 然后是 Cutout. AutoAugment 使用强化学习来查找包含来自 Python Imaging Library 的转换的增强策略. 这需要标记数据来学习增强策略, 这使得在可用标记数据有限的 SSL 设置中使用存在问题. 因此, 使用不需要利用标记数据学习增强策略的 AutoAugment 变体, 例如 RandAugment 和 CTAugment. RandAugment 和 CTAugment 都没有使用学习策略, 而是为每个样本随机选择转换. 对于 RandAugment, 控制所有失真严重程度的幅度是从预定义的范围内随机采样的. 具有随机幅度的 RandAugment 也被用于 UDA. 而对于 CTAugment, 单个变换的幅度是即时学习的.
其他
一些其他重要因素会影响 SSL 的性能, 例如: architecture, optimizer, training schedule 等. 经过实验, 文中发现正则化尤为重要. 在所有的模型和实验中, 使用简单的权重衰减正则化. 同时发现使用 Adam 优化器会导致更差的性能, 而使用 SGD 则没有这种情况, 另外, 使用 SGD 和使用 Nesterov 之间没有存在实质性差异. 对于学习率, 使用余弦学习率衰减. 它将学习率设置为 η cos 7 π k 16 K \eta \cos \frac{7\pi k}{16K} ηcos16K7πk, 其中 η \eta η 是初始学习率, k k k 是当前训练步长, K K K 是总学习率训练步骤. 最后, 使用模型参数的指数移动平均值(EMA)报告最终性能.
FixMatch 可以很容易地使用 SSL 文献中的技术进行扩展. 例如, 来自 ReMixMatch 的增强锚定和分布对齐. 此外, 可以用与模态无关的增强策略, 例如 MixUp 或对抗性扰动代替 FixMatch 中的强增强. 对抗性扰动在 VAT, Adversarial Dropout 中已经应用. MixUp 也在 MixMatch, ICT 中成功应用.