概
本文是通过固定教师网络(具有鲁棒性), 让学生网络去学习教师网络的鲁棒特征. 相较于一般的distillation 方法, 本文新加了reweight机制, 另外其损失函数非一般的交叉熵, 而是最近流行的对比损失.
主要内容
本文的思想是利用robust的教师网络
f
t
f^t
ft来辅助训练学生网络
f
s
f^s
fs, 假设有输入
(
x
,
y
)
(x, y)
(x,y), 通过网络得到特征
t
+
:
=
f
t
(
x
)
,
s
+
:
=
f
s
(
x
)
,
t^+:= f^t(x), s^+:=f^s(x),
t+:=ft(x),s+:=fs(x),
则
(
t
+
,
s
+
)
(t^+, s^+)
(t+,s+)构成正样本对, 自然我们需要学生网络提取的特征
s
+
s^+
s+能够逼近
t
+
t^+
t+, 进一步, 构建负样本对, 采样样本
{
x
1
−
,
x
2
−
,
…
,
x
k
−
}
\{x_1^-, x_2^-, \ldots, x_k^- \}
{x1−,x2−,…,xk−}, 同时得到负样本对
(
t
+
,
s
i
−
)
(t^+,s_i^-)
(t+,si−), 其中
s
i
−
=
f
s
(
x
i
−
)
s_i^-=f^s(x_i^-)
si−=fs(xi−). 总的样本对就是
S
p
a
i
r
:
=
{
(
t
+
,
s
+
)
,
(
t
+
,
s
1
−
)
,
…
,
(
t
+
,
s
k
−
)
}
.
\mathcal{S}_{pair} := \{(t^+, s^+), (t^+, s_1^-), \ldots, (t^+, s_k^-)\}.
Spair:={(t+,s+),(t+,s1−),…,(t+,sk−)}.
根据负样本采样的损失, 最大化
J
(
θ
)
:
=
E
(
t
,
s
)
∼
p
(
t
,
s
)
log
P
(
1
∣
t
,
s
;
θ
)
+
E
(
t
,
s
)
∼
q
(
t
,
s
)
log
P
(
0
∣
t
,
s
;
θ
)
.
J(\theta):= \mathbb{E}_{(t,s)\sim p(t,s)} \log P(1|t,s;\theta) + \mathbb{E}_{(t,s)\sim q(t,s)} \log P(0|t,s;\theta).
J(θ):=E(t,s)∼p(t,s)logP(1∣t,s;θ)+E(t,s)∼q(t,s)logP(0∣t,s;θ).
当然对于本文的问题需要特殊化, 既然先验
P
(
C
=
1
)
=
1
k
+
1
,
P
(
C
=
0
)
=
k
k
+
1
P(C=1)=\frac{1}{k+1}, P(C=0)=\frac{k}{k+1}
P(C=1)=k+11,P(C=0)=k+1k, 故
J
(
θ
)
:
=
E
(
t
,
s
)
∼
p
(
t
,
s
)
log
P
(
1
∣
t
,
s
;
θ
)
+
k
⋅
E
(
t
,
s
)
∼
q
(
t
,
s
)
log
P
(
0
∣
t
,
s
;
θ
)
.
J(\theta):= \mathbb{E}_{(t,s)\sim p(t,s)} \log P(1|t,s;\theta) + k\cdot \mathbb{E}_{(t,s)\sim q(t,s)} \log P(0|t,s;\theta).
J(θ):=E(t,s)∼p(t,s)logP(1∣t,s;θ)+k⋅E(t,s)∼q(t,s)logP(0∣t,s;θ).
q ( t , s ) q(t,s) q(t,s)是一个区别于 p ( t , s ) p(t,s) p(t,s)的分布, 本文采用了 p ( t ) q ( s ) p(t)q(s) p(t)q(s).
作者进一步对前一项加了解释
P
(
1
∣
t
,
s
;
θ
)
=
P
(
t
,
s
)
P
(
C
=
1
)
P
(
t
,
s
)
P
(
C
=
1
)
+
P
(
t
)
P
(
s
)
P
(
C
=
0
)
≤
P
(
t
,
s
)
k
⋅
P
(
t
)
P
(
s
)
,
\begin{array}{ll} P(1|t,s;\theta) &= \frac{P(t,s)P(C=1)}{P(t,s)P(C=1) + P(t)P(s)P(C=0)} \\ &\le \frac{P(t,s)}{k\cdot P(t)P(s)}, \\ \end{array}
P(1∣t,s;θ)=P(t,s)P(C=1)+P(t)P(s)P(C=0)P(t,s)P(C=1)≤k⋅P(t)P(s)P(t,s),
故
E
(
t
,
s
)
∼
p
(
t
,
s
)
log
P
(
1
∣
t
,
s
;
θ
)
+
log
k
≤
I
(
t
,
s
)
.
\mathbb{E}_{(t,s)\sim p(t,s)} \log P(1|t,s;\theta) + \log k\le I(t,s).
E(t,s)∼p(t,s)logP(1∣t,s;θ)+logk≤I(t,s).
又
J
(
θ
)
J(\theta)
J(θ)的第二项是负的, 故
J
(
θ
)
≤
I
(
t
,
s
)
,
J(\theta) \le I(t,s),
J(θ)≤I(t,s),
所以最大化
J
(
θ
)
J(\theta)
J(θ)能够一定程度上最大化
t
,
s
t,s
t,s的互信息.
reweight
教师网络一般要求精度(干净数据集上的准确率)比较高, 但是通过对抗训练所生成的教师网络往往并不具有这一特点, 所以作者采取的做法是, 对特征
t
t
t根据其置信度来加权
w
w
w, 最后损失为
L
(
θ
)
:
=
E
(
t
,
s
)
∼
p
(
t
,
s
)
w
t
log
P
(
1
∣
t
,
s
;
θ
)
+
k
⋅
E
(
t
,
s
)
∼
p
(
t
)
p
(
s
)
w
t
log
P
(
0
∣
t
,
s
;
θ
)
,
\mathcal{L}(\theta) := \mathbb{E}_{(t,s)\sim p(t,s)} w_t \log P(1|t,s;\theta) + k\cdot \mathbb{E}_{(t,s)\sim p(t)p(s)} w_t \log P(0|t,s;\theta),
L(θ):=E(t,s)∼p(t,s)wtlogP(1∣t,s;θ)+k⋅E(t,s)∼p(t)p(s)wtlogP(0∣t,s;θ),
其中
w
t
←
p
y
p
r
e
d
=
y
(
f
t
,
t
+
)
∈
[
0
,
1
]
.
w_t \leftarrow p_{ypred=y}(f^t,t^+) \in [0, 1].
wt←pypred=y(ft,t+)∈[0,1].
即
w
t
w_t
wt为教师网络判断
t
+
t^+
t+类别为
y
y
y(真实类别)的概率.
拟合概率 P ( 1 ∣ t , s ; θ ) P(1|t,s;\theta) P(1∣t,s;θ)
在负采样中, 这类概率是直接用逻辑斯蒂回归做的, 本文采用
P
(
1
∣
t
,
s
;
θ
)
=
h
(
t
,
s
)
=
e
t
T
s
/
τ
e
t
T
s
/
τ
+
k
M
,
P(1|t,s;\theta) = h(t,s) = \frac{e^{t^Ts/\tau}}{e^{t^Ts/\tau}+\frac{k}{M}},
P(1∣t,s;θ)=h(t,s)=etTs/τ+MketTs/τ,
其中
M
M
M为数据集的样本个数.
会不会
e
t
T
s
/
τ
e
t
T
s
/
τ
+
γ
⋅
k
M
2
,
\frac{e^{t^Ts/\tau}}{e^{t^Ts/\tau}+\gamma \cdot \frac{k}{M^2}},
etTs/τ+γ⋅M2ketTs/τ,
把
γ
\gamma
γ也作为一个参数训练符合NCE呢?
实验的细节
文中有如此一段话
we sample negatives from different classes rather than different instances, when picking up a positive sample from the same class.
也就是说在实际实验中, t + , s + t^+,s^+ t+,s+对应的类别是同一类的, t + , s − t^+, s^- t+,s−对应的类别不是同一类的.
In our view, adversarial examples are like hard examples supporting the decision boundaries. Without hard examples, the distilled models would certainly make mistakes. Thus, we adopt a self-supervised way to generate adversarial examples using Projected Gradient Descent (PGD).
也就是说, t , s t, s t,s都是对抗样本?
超参数: k = 16384 k=16384 k=16384, τ = 0.1 \tau=0.1 τ=0.1.
疑问
算法中的采样都是针对单个样本的, 但是我想实际训练的时候应该还是batch的, 不然太慢了, 但是如果是batch的话, 怎么采样呢?