High-Frequency Component Helps Explain the Generalization of Convolutional Neural Networks
题目:高频成分有助于CNN泛化能力的解释
来源:CVPR 2020 oral 卡内基梅隆大学
1. Motivation
-
CNN的泛化能力不直观,不能直接对其进行解释。
-
目前对CNN泛化能力进行解释的方法有针对随机梯度下降、不同的复杂性度量等方面的研究,本文是基于傅里叶变换的角度,从频域入手解释CNN进行特征提取的性质。
2. Contribution
核心观点:CNN可以比人类从更高粒度的层面观察图像的特征,换句话说,就是CNN可以从高频信息中对图像进行观察,而人类只能从低频信息中进行观察,这一差异导致了CNN会做出对人类而言不够直观的泛化表现。
基于以上观点,本文展示了:
-
CNN如何利用高频成分在鲁棒性以及准确性上进行权衡;
-
以图像频谱为工具,提出了一些假设解释CNN的几种泛化行为;
-
从频率角度给出提升CNN对简单攻击鲁棒性的方法。
3. CNN会在鲁棒性和准确性上进行衡量
结论:在通过样本训练出的CNN网络中,一定存在一个样本,使得CNN在任意距离度量方法与鲁棒性阈值下都无法同时满足鲁棒性和准确性均为1的要求,从而导致CNN会在二者之间进行一定的权衡。
为了证明这一结论,本文提出了两个假设:
假设1:在对图像进行观察时,人类只能从低频信息中做出预测,而CNN网络同时会利用高频和低频信息。
(人类倾向于使用数据的语义信息(低频)对图像进行观察,而CNN不仅会使用低频的语义信息,还能使用与语义信息具有特定相关性的高频信息联合对图像进行观察)
假设2:对于一个CNN模型而言,必然存在一个样本 < x , y > <\mathrm{x}, \mathrm{y}> <x,y>,使得: f ( x ; θ ) ≠ f ( x l ; θ ) f(\mathrm{x} ; \theta) \neq f\left(\mathrm{x}_{l} ; \theta\right) f(x;θ)=f(xl;θ)
假设的证明:
作者用ResNet18在Cifar10上分别使用图像的高频和低频部分进行预测,实验结果发现CNN模型在低频成分上的预测很不准确,高频成分的预测结果与原始图像更加吻合,同时证明了假设1和假设2的成立。
(上图中,第一列为原始图像及对应的预测结果,第二列为图像的低频成分,第三列为高频成分。)
3.1 CNN会利用高频组件
图像信息
x
x
x可以视为是高频信息
x
h
x_h
xh与低频信息
x
l
x_l
xl的组合:
x
=
[
x
h
,
x
l
]
x=[x_h, x_l]
x=[xh,xl]
对于高频和低频信息,首先将
x
x
x进行傅里叶变换,得到频域表示
z
z
z,之后使用一个阈值函数
t
(
z
;
r
)
t(\mathrm{z} ; r)
t(z;r)进行高频分量和低频分量的区分,将高频分量和低频分量再使用傅里叶逆变换回时域,即得到高频和低频信息:
z
=
F
(
x
)
,
z
l
,
z
h
=
t
(
z
;
r
)
x
l
=
F
−
1
(
z
l
)
,
x
h
=
F
−
1
(
z
h
)
\begin{aligned} \mathrm{z}=\mathcal{F}(\mathrm{x}), & \mathrm{z}_{l}, \mathrm{z}_{h}=t(\mathrm{z} ; r) \\ \mathrm{x}_{l}=\mathcal{F}^{-1}\left(\mathrm{z}_{l}\right), & \mathrm{x}_{h}=\mathcal{F}^{-1}\left(\mathrm{z}_{h}\right) \end{aligned}
z=F(x),xl=F−1(zl),zl,zh=t(z;r)xh=F−1(zh)
t
(
z
;
r
)
t(\mathrm{z} ; r)
t(z;r)具有一个超参数
r
r
r,具体实现时,通过计算每一个像素点
(
i
,
j
)
(i, j)
(i,j)与其傅里叶变换之后的中心
(
c
i
,
c
j
)
(c_i, c_j)
(ci,cj)之间的欧氏距离,作为区分依据:
z
l
(
i
,
j
)
=
{
z
(
i
,
j
)
,
if
d
(
(
i
,
j
)
,
(
c
i
,
c
j
)
)
≤
r
0
,
otherwise
z
h
(
i
,
j
)
=
{
0
,
if
d
(
(
i
,
j
)
,
(
c
i
,
c
j
)
)
≤
r
z
(
i
,
j
)
,
otherwise
\begin{array}{l} \mathbf{z}_{l}(i, j)=\left\{\begin{array}{ll} \mathbf{z}(i, j), & \text { if } d\left((i, j),\left(c_{i}, c_{j}\right)\right) \leq r \\ 0, & \text { otherwise } \end{array}\right. \\ \mathbf{z}_{h}(i, j)=\left\{\begin{array}{ll} 0, & \text { if } d\left((i, j),\left(c_{i}, c_{j}\right)\right) \leq r \\ \mathbf{z}(i, j), & \text { otherwise } \end{array}\right. \end{array}
zl(i,j)={z(i,j),0, if d((i,j),(ci,cj))≤r otherwise zh(i,j)={0,z(i,j), if d((i,j),(ci,cj))≤r otherwise
根据假设1,人类对图像进行预测的方法为:
y
:
=
f
(
x
;
H
)
=
f
(
x
l
;
H
)
\mathbf{y}:=f(\mathbf{x} ; \mathcal{H})=f\left(\mathbf{x}_{l} ; \mathcal{H}\right)
y:=f(x;H)=f(xl;H)
CNN网络的预测模式为:
arg
min
θ
l
(
f
(
x
;
θ
)
,
y
)
=
arg
min
θ
l
(
f
(
{
x
l
,
x
h
}
;
θ
)
,
y
)
\underset{\theta}{\arg \min } l(f(\mathbf{x} ; \theta), \mathbf{y})= \underset{\theta}{\arg \min } l(f(\{\mathbf{x}_l, \mathbf{x}_h \}; \theta), \mathbf{y})
θargminl(f(x;θ),y)=θargminl(f({xl,xh};θ),y)
基于上述推断,可以得到以下结论:
-
CNN会挖掘高频成分,并在此基础上进行泛化,由于高频信息难以被人类观察,导致了CNN的泛化能力对于人类而言很不直观;
-
假如针对高频成分进行扰动生成对抗性样本,会导致CNN的预测结果不准确,而对于人类而言没有什么变化。
3.2 CNN会在鲁棒性与准确性之间进行权衡
本文定义的CNN预测准确性:
E
(
x
,
y
)
α
(
f
(
x
;
θ
)
,
y
)
\mathbb{E}_{(\mathbf{x}, \mathbf{y})}\alpha\left(f\left(\mathbf{x} ; \theta\right), \mathbf{y}\right)
E(x,y)α(f(x;θ),y)
鲁棒性:
E
(
x
,
y
)
min
x
′
:
d
(
x
′
,
x
)
≤
ϵ
α
(
f
(
x
′
;
θ
)
,
y
)
\mathbb{E}_{(\mathbf{x}, \mathbf{y})} \min _{\mathbf{x}^{\prime}: d\left(\mathbf{x}^{\prime}, \mathbf{x}\right) \leq \epsilon} \alpha\left(f\left(\mathbf{x}^{\prime} ; \theta\right), \mathbf{y}\right)
E(x,y)x′:d(x′,x)≤ϵminα(f(x′;θ),y)
根据假设1和假设2,必然存在一个样本
<
x
,
y
>
<\mathrm{x}, \mathrm{y}>
<x,y>,不能同时满足准确性和鲁棒性均为1,故而CNN会对二者进行一个权衡。
4. 从数据角度挖掘CNN的行为
目前的研究表明,CNN对数据具有记忆能力,这意味着理论上CNN是可以通过记忆数据的信息满足训练精度的要求,而通常CNN并不采用简单的记忆来学习数据的特征,而是在鲁棒性与准确性之间进行权衡,保证一定的泛化能力。
此外,已经有研究发现CNN对于标签混淆数据(标签和图像类别不一一对应)的学习能力很强,根据前面的假设,正确标注的标签与图像的低频信息相关联,而标签混淆数据的低频信息不再与图像标签相对应,但是CNN仍然可以学得很好。
基于以上两个发现,本文对CNN的数据学习模式做了以下推理:
在正确标注的样本中,CNN倾向于先学习低频信息,再逐渐提取高频信息,以提升精度;
在混淆标签样本中,低频信息不再与标签相关联,所以模型对于低频信息与高频信息不再区别对待,意味着CNN开始记忆数据本身的信息。
(可以理解为对于正确标注的样本,CNN可以像做题一样建立理论与实际题目之间的相关性,而对于标签混淆样本,CNN就像面对不会做的题,为了应付考试只能把题背下来。)
为了验证以上想法的正确性,作者分别利用正确标注和混淆标注的CIFAR10数据的低频信息,在ResNet18上训练分类器,观察其收敛情况,如下图所示:
从上图中可以发现,标签混淆数据的模型收敛速度比正常标注的模型更慢,可以认为CNN确实更倾向于从信息中学习,而不是记忆;而在包含更少的低频信息时(r=4/8),正常标注数据训练的模型精度比标签混淆数据更高,而当r较大时,二者的精度没有明显差异,这意味着CNN优先从低频信息进行学习,而标签混淆信息没有低频信息与标签之间的学习关系,只能将低频信息与高频信息同等处理,在r更小,总的信息量更少的情况下,学习效果更差。(作者在MNIST,FashionMNIST,ImageNet上都进行了类似实验,得到差不多的结果)
这里衍生出一个新的问题,即CNN通过对混淆标签数据的学习表明其可以从高频和低频信息中提取到有效信息,但是为什么CNN仍然倾向于从低频信息中学习呢?
作者认为是因为标签信息是人工标注的,即人为地把低频信息与标签关联起来了,因此对于给定标签信息的数据,低频信息对CNN loss的影响更大。
为了验证这一观点,作者在正确标注数据下,用图像的高频和低频成分分别训练了一个分类器,且使其具有较高的预测精度,之后用这两个分类器在原始测试集上进行预测,最后发现在低频成分上训练的分类器泛化能力显著强于高频成分的分类器,证明了其猜想的合理性:
5.从频率角度研究训练手段的作用
5.1 Batch Size
作者设置了不同的batch size,并使用不同的频率成分数据进行训练,得到下图的测试结果:
观察“Train”与“Test”两条曲线,可以发现小的batch size有助于提高精度,而大batch size有助于缩小模型泛化差距;观察代表高频成分的点状线,发现更大的batch size中,不同r值的预测精度差距更小 ,意味着大batch size中高频成分的变化更小,从而缩小了泛化差距;观察代表低频成分的实线,发现大batch size使得loss更陡峭,可能是因为大batch size中包含更多的低频可归纳信息(没有很理解这一点)。
5.2 训练技巧
作者比较了Dropout,Mix-up,BatchNorm和Adversarial Training四种方法对模型预测精度的影响:
发现Dropout和Mix-up对精度影响不大,但可以看出Mix-up捕获了更多的高频信息,对抗性学习的引入会降低精度,但是缩小了泛化差距(诱导CNN重视鲁棒性,进而牺牲精度)。
BN加速了模型的收敛速度,同时有助于捕获高频成分信息。作者指出,BN的核心作用就是平衡不同频率分量的比重,一般来说,高频信息比低频信息低几个量级,因此,BN注重于对高频成分的捕捉,为了说明这一点,作者对BN进行了进一步的实验:
作者对比了加BN与不加BN的网络,在只使用低频成分训练时的表现,发现加了BN之后模型没有明显提升:
5.3 Networks
作者比较了LeNet、AlexNet、VGG、ResNet,最终发现ResNet有更好的准确率,更小的泛化差距和更弱的捕捉高频信息的能力。
5.4 Optimizer
作者比较了SGD、ADAM、AdaGrad、AdaDelta、RMSprop五种优化器,发现SGD更倾向于捕捉高频信息,其它的没有明显差异。
6. 从对抗攻防研究卷积核
根据前面的研究已经知道,对抗性学习有助于提升模型的鲁棒性,作者通过研究发现,与一般的CNN模型相比,经过对抗性学习的模型具有更加光滑的卷积核,可视化结果如下:
卷积核的光滑度与频率有什么关系呢?更加光滑的卷积核意味着卷积核内部的权重之间具有更小的突变,在对特征或图像进行卷积时,将会产生更少的高频信息。
由此很容易得知,若想增加模型的鲁棒性,可以通过一定手段改善卷积核的光滑程度,作者提出了如下的trick:
定义卷积核中的每个元素都有八个相邻元素,将当前位置的权值加上临近八个相邻位置的权值即可使得卷积核更加光滑:
w
i
,
j
=
w
i
,
j
+
∑
(
h
,
k
)
∈
N
(
i
,
j
)
ρ
w
h
,
k
\mathbf{w}_{i, j}=\mathbf{w}_{i, j}+\sum_{(h, k) \in \mathcal{N}(i, j)} \rho \mathbf{w}_{h, k}
wi,j=wi,j+(h,k)∈N(i,j)∑ρwh,k
上图中经过光滑处理之后的卷积核的可视化结果如下:
作者使用这两个方法处理了两个模型,发现进行卷积核光滑操作后,模型精度下降,但是无论是对抗性模型还是单纯模型的鲁棒性均有所提升:
7. 低频成分对目标检测任务的影响
作者将频率方法应用于目标检测任务时,没能很好地对其进行解释:
在只使用低频成分进行目标检测时,模型在下图的上半部分获得了更低的MAP值,而在下半部分的图像中获得了更好的MAP结果,而对于人类而言没有明显区别,显示出CNN与人类思维模式的仍有差异。