为了防御之前提出的FGSM和JSMA的攻击方式,作者根据之前hinton提出的蒸馏学习的方式,再此基础上稍作修改得到了防御蒸馏模型,并理论推导了防御有效性的原因。
蒸馏学习是原先hinton提出用来减少模型复杂度并且不会降低泛化性能的方法,具体就是在指定温度下,先训练一个教师模型,再将教师模型在数据集上输出的类别概率标记作为软标签训练学生模型。而在防御蒸馏模型中,选择两个相同的模型作为教师模型和学生模型。
具体过程如下:
1、先用硬标签训练教师模型,假设温度为
T
T
,则教师模型的softmax层的输出为:
F(X)=[ezi(X)/T∑N−1l=0ezl(X)/T]i∈0…N−1
F
(
X
)
=
[
e
z
i
(
X
)
/
T
∑
l
=
0
N
−
1
e
z
l
(
X
)
/
T
]
i
∈
0
…
N
−
1
再此基础上应用交叉熵损失训练模型。
2、然后用教师模型输出类别概率 F(X) F ( X ) (注意,这里还是保持了温度 T T )实际上,这和,即普通情况下训练模型,并没有什么太大区别,但我们还是与原文保持一致。
3、对于学生模型,我们还是利用温度 T T 下的输出计算交叉熵损失函数,不过类别标签应用之前教师模型输出的软标签,进而进行训练。
对于使用软标签带来的好处,主要在于使用软标签 F(X) F ( X ) 使得神经网络能够在概率向量中找到的附加知识。 这个额外的熵编码了类之间的相对差异。 例如,在手写数字识别的背景下,给定一些手写数字的图像 X X ,模型可以评估数字 7 7 到的概率以及标签 1 1 到的概率,这表明 7 7 和之间有一些结构相似性。
4、对于模型的预测输出,我们反而将温度 T T 降为,从而以高置信度来预测未知输入的类别。
实际上我们根本不需要前面的教师模型,只需要将 F(X) F ( X ) 作为神经网络的输出来最小化交叉熵损失函数既可以达到防御FGSM和JSMA的攻击。
作者也对防御蒸馏的有效性进行的分析,对softmax层的输出求梯度可以很容易的得出:
∂Fi(X)∂Xj∣∣∣T=∂∂Xj(ezi/T∑N−1l=0ezl/T)=1g2(X)(∂ezi(X)/T∂Xjg(X)−ezi(X)/T∂g(X)∂Xj)=1g2(X)ezi/TT(∑l=0N−1∂zi∂Xjezl/T−∑l=0N−1∂zl∂Xjezl/T)=1Tezi/Tg2(X)(∑l=0N−1(∂zi∂Xj−∂zl∂Xj)ezl/T)
∂
F
i
(
X
)
∂
X
j
|
T
=
∂
∂
X
j
(
e
z
i
/
T
∑
l
=
0
N
−
1
e
z
l
/
T
)
=
1
g
2
(
X
)
(
∂
e
z
i
(
X
)
/
T
∂
X
j
g
(
X
)
−
e
z
i
(
X
)
/
T
∂
g
(
X
)
∂
X
j
)
=
1
g
2
(
X
)
e
z
i
/
T
T
(
∑
l
=
0
N
−
1
∂
z
i
∂
X
j
e
z
l
/
T
−
∑
l
=
0
N
−
1
∂
z
l
∂
X
j
e
z
l
/
T
)
=
1
T
e
z
i
/
T
g
2
(
X
)
(
∑
l
=
0
N
−
1
(
∂
z
i
∂
X
j
−
∂
z
l
∂
X
j
)
e
z
l
/
T
)
其中 g(X)=∑N−1l=0ezl(X)/T g ( X ) = ∑ l = 0 N − 1 e z l ( X ) / T ,并且我们记 zi(X) z i ( X ) 为 zi z i 。
但实际上这里的分析值得商榷,我们来重新做一个分析,首先我们假设用防御蒸馏训练得到的防御蒸馏模型类别 i i 的logits输出为(也就是softmax层的输入),而原始模型类别 i i 的logits输出为。模型训练时,为了使交叉熵损失函数足够小,我们类别输出 z′i(X) z i ′ ( X ) 之间的差距必须足够大,原先,我们只需要目标类别 l l 的logits输出比其余类别 zi(X),i≠l z i ( X ) , i ≠ l 大10左右即可。( e−10≈4.5×10−5 e − 10 ≈ 4.5 × 10 − 5 ,这是我随便选的数字),即可使得我们要求的类别 l l 的概率输出近似为1:
而训练防御蒸馏模型时,,假设我们温度选择 T=100 T = 100 ,那么我们在该温度下要求目标类别的概率输出为 1 1 ,则需要满足:
此时若还按照刚才的 e−10 e − 10 的要求,则必须 z′i(X)−z′l(X)=−10∗T=−1000 z i ′ ( X ) − z l ′ ( X ) = − 10 ∗ T = − 1000 ,而训练完毕后模型预测时温度降为 T=1 T = 1 ,此时差1000就是绝对的 0 0 和了,我们用交叉熵 loss=−logFl(X) l o s s = − log F l ( X ) 对 X X 求梯度可以得到:
注意这里是分子分母同时除以 ez′l(X) e z l ′ ( X ) 近似得到的,因此蒸馏训练之后的损失函数对样本X的梯度近似等于0,我们的实验结果也是如此。
同理,我们重新分析对softmax层的输出求的梯度:
∂Fi(X)∂X=∂∂Xj(ez′i∑N−1l=0ez′l)=1g2(X)(∂ez′i(X)∂Xg(X)−ez′i(X)∂g(X)∂X)=1g2(X)ez′i(∑l=0N−1∂z′i∂Xez′l−∑l=0N−1∂z′l∂Xez′l)=ez′ig2(X)(∑l=0N−1(∂z′i∂X−∂z′l∂X)ez′l)=∂z′l∂X−∂z′l∂X=0
∂
F
i
(
X
)
∂
X
=
∂
∂
X
j
(
e
z
i
′
∑
l
=
0
N
−
1
e
z
l
′
)
=
1
g
2
(
X
)
(
∂
e
z
i
′
(
X
)
∂
X
g
(
X
)
−
e
z
i
′
(
X
)
∂
g
(
X
)
∂
X
)
=
1
g
2
(
X
)
e
z
i
′
(
∑
l
=
0
N
−
1
∂
z
i
′
∂
X
e
z
l
′
−
∑
l
=
0
N
−
1
∂
z
l
′
∂
X
e
z
l
′
)
=
e
z
i
′
g
2
(
X
)
(
∑
l
=
0
N
−
1
(
∂
z
i
′
∂
X
−
∂
z
l
′
∂
X
)
e
z
l
′
)
=
∂
z
l
′
∂
X
−
∂
z
l
′
∂
X
=
0
其中 g(X)=∑N−1l=0ez′l(X) g ( X ) = ∑ l = 0 N − 1 e z l ′ ( X ) ,这里也是分子分母同时除以 e2z′l(X) e 2 z l ′ ( X ) 得到的。这也与我们的实验相符(雅可比矩阵为0)。但是实际上JSMA的攻击选择的是logits的输出,因此实际上防御蒸馏根本没办法抵抗JSMA的攻击。
防御蒸馏实际上是一个典型的梯度遮蔽的方式来防御对抗攻击,实际上我们即使不知道真正的梯度,采用近似的梯度,也能够成功攻击。C&W后来提出了一个新的攻击方式成功攻击了该防御模型。