摘要:
对比学习常用于无监督视觉表征学习中,需要依赖于大量手工标注的增强数据,而各种生成模型可以作为数据膨胀data inflation的手段
然而它们其实可能对对比学习有害,作者分析了原因,揭示了数据膨胀和数据增强的反比关系
作者对性能下降现象做了理论解释,主要推理了数据膨胀前提下的泛化边界,并首次启发,提出了Adaptive Inflation (AdaInf),是一种以数据为中心的数据膨胀策略
作者采用SimCLR在CIFAR上进行测试,使用AdaInf得到了效果提升
Chapter1 Intro
文中指出,数据膨胀(data inflation)就是简单使用生成模型产出的数据文中指出,数据膨胀就是简单使用生成模型产出的数据
而数据增强data augmentation是指对数据进行一系列操作(如裁剪,旋转),增加正负对样本以促进对比学习性能的手段
作者针对data inflation和data augmentation两方面进行性能下降的原因研究,发现data inflation中生成图像的质量作用有限,调整real和generated数据的比例可以改善性能。但在data augmentation方面,作者意外发现在采用data inflation的情况下,较弱的data augmentation竟可以提高性能
为了解释这一现象,作者剖析了data inflation和data augmentation的互补作用,并基于相关见解提出了Adaptive Inflation (AdaInf)策略,可以适应性调整数据增强强度和数据膨胀的混合比例,在不带来额外计算的前提下提高下游任务的性能。
Chapter3 分析性能下降的影响因素
关于data inflation
define: D d \mathcal{D}_d Dd: real data D g \mathcal{D}_g Dg: generated data
distribution of D d \mathcal{D}_d Dd and D g \mathcal{D}_g Dg: P d , P g P_d, P_g Pd,Pg -----------> total overall distribution: P t = β P d + ( 1 − β ) P g P_t = \beta P_d + (1-\beta) P_g Pt=βPd+(1−β)Pg
where β = ∣ D d ∣ ∣ D d ∣ + ∣ D g ∣ \beta = \frac{|D_d|}{|D_d| + |D_g|} β=∣Dd∣+∣Dg∣∣Dd∣
若效果越好,则总的数据和真实数据的差异应当越小越小,该差异可在分布空间中体现,而分布的差异可以用全变分距离(total variation distance)
全变分距离介绍:https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures
T V ( P , Q ) = 1 2 ∫ ∣ d P ( x ) − d Q ( x ) ∣ TV(P,Q)=\frac{1}{2} \int |\mathrm{d}P(x)- \mathrm{d}Q(x)| TV(P,Q)=21∫∣dP(x)−dQ(x)∣ ,一些情况下系数可以省略
因此,可以找到最优化目标: minimize: D t v ( P t , P d ) \mathrm{D}_{\mathrm{tv}}(P_t, P_d) Dtv(Pt,Pd)
该式子中的总分布t包含了真实数据和生成数据,可以转化到仅包含生成数据的分布 P g P_g Pg
定理3.1
D
t
v
(
P
t
,
P
d
)
=
(
1
−
β
)
D
t
v
(
P
g
,
P
d
)
(
1
)
\mathrm{D}_{\mathrm{tv}}(P_t, P_d)=(1-\beta)\mathrm{D}_{\mathrm{tv}}(P_g, P_d) \quad \quad (1)
Dtv(Pt,Pd)=(1−β)Dtv(Pg,Pd)(1)
证明:
由此可见,总体数据质量的好坏可以由真实数据比例 β \beta β和生成数据分布 P g P_g Pg决定。
作者使用了不同质量的扩散模型来改进 P g P_g Pg,发现收效甚微
作者调整不同的
β
\beta
β值,发现真实-生成数据复制比例达到10:1时,可以取得最佳性能
请注意,上面说的10:1指的是复制对应的数据的比例,即真实数据复制10次,生成数据仅1次,之后混合,而不是指真实数据是生成数据的10倍
关于data augmentation
作者进一步研究了数据增强的效果,主要采用的data augmentation策略为随机缩放裁剪,可以改变相对最小裁剪比例 α \alpha α来控制数据增强的强弱大小, a a a越小,强度越大
对比实验采用SimLCR网络运行,通过控制生成数据量来控制data inflation强度,控制 a a a来控制data augmentation强度,作者发现了加入适量的data inflation,同时伴随较弱的data augmentation可以有效提高性能
Chapter4 对数据膨胀,数据增强如何影响性能做理论分析
作者采用了图结构对数据增强进行阐述,把数据样本作为结点,把增强手段作为边
define: 膨胀数据集:
X
‾
\overline{\mathcal{X}}
X(包含了真实数据和生成数据); 其经过data augmentation的数据集:
X
\mathcal{X}
X
建立关于
X
\mathcal{X}
X的增强数据图,用邻接矩阵
A
∈
R
n
×
n
A \in \mathbb{R}^{n\times n}
A∈Rn×n(
n
n
n应该是所有增强样本的数量)表示,
A
A
A表示在数据增强条件下的正样本联合概率
其中,对于由
x
‾
\overline{x}
x增强得到的样本
x
,
x
′
x,x^{\prime}
x,x′, 有
A
x
,
x
′
=
E
x
‾
∼
P
X
‾
A
(
x
∣
x
‾
)
⋅
A
(
x
′
∣
x
‾
)
A_{x,x\prime}=\mathbb{E}_{\overline{x} \sim \mathcal{P}_{\overline{\mathcal{X}}}}\mathcal{A}(x|\overline{x})\cdot \mathcal{A}(x^{\prime}|\overline{x})
Ax,x′=Ex∼PXA(x∣x)⋅A(x′∣x),其中
A
(
x
∣
x
‾
)
\mathcal{A}(x|\overline{x})
A(x∣x)表示由
x
‾
\overline{x}
x增强的
x
x
x是正样本的条件概率,
E
\mathbb{E}
E表示求选出数据的期望
引入图拉普拉斯:
L
=
I
−
D
−
1
2
A
D
−
1
2
\mathcal{L}=I-D^{-\frac{1}{2}}AD^{-\frac{1}{2}}
L=I−D−21AD−21,其中
D
D
D是个对角矩阵,可以定义为:
D
x
x
=
∑
x
′
A
x
x
′
D_{xx}=\sum_{x\prime}A_{xx\prime}
Dxx=∑x′Axx′,
I
I
I是单位和
D
,
A
D,A
D,A一样大小的单位矩阵
若认为生成数据和真实数据是相同的无差异,那么真实数据集
X
‾
r
a
w
\overline{\mathcal{X}}_{\mathrm{raw}}
Xraw就可以认为是
X
‾
\overline{\mathcal{X}}
X的一个子集,同样仅针对真实数据的增强样本图也可认为是
A
A
A的一个子图
define:拉普拉斯矩阵
L
\mathcal{L}
L的
N
N
N个特征值:
0
=
λ
1
⩽
λ
2
⩽
⋯
⩽
λ
N
⩽
2
0=\lambda_{1} \leqslant \lambda_{2} \leqslant \cdots \leqslant \lambda_{N} \leqslant 2
0=λ1⩽λ2⩽⋯⩽λN⩽2
作者接下来以线性探测(Linear Probing)任务为例进行原理说明
设有一线性分类器
g
f
,
B
g_{f,B}
gf,B,如下图示:
其中线性分类器权重矩阵
B
∈
R
k
×
r
B \in \mathbb{R}^{k \times r}
B∈Rk×r,
k
k
k 是特征通道数,
r
r
r 是类别数
膨胀数据集样本
x
‾
\overline{x}
x的类别通过投票分类器决定,即
g
‾
f
,
B
(
x
‾
)
:
=
a
r
g
m
a
x
i
∈
[
r
]
P
r
x
∼
A
(
⋅
∣
x
‾
)
(
g
f
,
B
(
x
)
=
i
)
\overline{g}_{f,B}(\overline{x}):=\mathrm{argmax}_{i \in [r]} \mathrm{Pr}_{x \sim \mathcal{A}(\cdot|\overline{x})}(g_{f,B}(x)=i)
gf,B(x):=argmaxi∈[r]Prx∼A(⋅∣x)(gf,B(x)=i),意思就是对所有
x
‾
\overline{x}
x的增强数据做预测,取预测为某一类别次数最多的那个类别作为
x
‾
\overline{x}
x的预测结果
define: 分类错误率
ε
(
f
,
B
)
\varepsilon(f,B)
ε(f,B),值越小,说明准确率越高
定理4.1 至少有
1
−
δ
1-\delta
1−δ的概率,对于最优的编码器
f
∗
f^*
f∗和学习的分类器权重
B
∗
B^*
B∗,线性探测误差存在以下上界:
ε
(
f
∗
,
B
∗
)
⩽
8
α
λ
k
+
1
+
16
α
+
2
(
1
−
β
)
D
T
V
(
P
d
,
P
g
)
(
2
)
\varepsilon(f^*,B^*) \leqslant \frac{8\alpha}{\lambda_{k+1}} + 16\alpha + 2(1-\beta)\mathrm{D}_{\mathrm{TV}}(P_d,P_g) \quad \quad (2)
ε(f∗,B∗)⩽λk+18α+16α+2(1−β)DTV(Pd,Pg)(2)
其中
α
=
E
x
‾
∼
P
d
,
x
∼
A
(
⋅
∣
x
‾
)
1
[
y
(
x
)
≠
y
(
x
‾
)
]
\alpha=\mathbb{E}_{\overline{x} \sim \mathcal{P}_{d}, x \sim \mathcal{A}(\cdot|\overline{x})}\mathbb{1}[y(x)\ne y(\overline{x})]
α=Ex∼Pd,x∼A(⋅∣x)1[y(x)=y(x)],即为
x
‾
\overline{x}
x增强为
x
x
x的过程中的标签错误率;
λ
k
+
1
\lambda_{k+1}
λk+1为
A
A
A的拉普拉斯矩阵
L
\mathcal{L}
L的第
k
+
1
k+1
k+1小的特征值
定理4.1的公式(2)是理解数据膨胀和数据增强影响模型识别准确率的核心。其值可由
β
,
D
T
V
(
P
g
,
P
d
)
\beta,\mathrm{D}_{\mathrm{TV}}(P_g,P_d)
β,DTV(Pg,Pd),
α
\alpha
α和
λ
k
+
1
\lambda_{k+1}
λk+1决定。通过第3章的分析已知混合比例
β
\beta
β是通过data inflation控制的,
D
T
V
(
P
g
,
P
d
)
\mathrm{D}_{\mathrm{TV}}(P_g,P_d)
DTV(Pg,Pd)取决于生成模型本身的质量,二者都可归类为data inflation策略
而决定公式(2)大小的另外两个因素,接下来做详细阐述。
标签错误(Labeling error) 由于random resize crop的作用,增强数据时会对原数据进行部分裁剪,且裁剪后的样本与原样本的标签类别一致,但有可能造成实际真值的变化。比如原本的图片是茶壶,裁剪后的实际内容变成了茶罐,然而标签值依旧是茶壶。而且增强程度越大(即裁剪图越小),越容易得到更局部的图像,增大标签错误率
α
\alpha
α。
图的连通程度(graph connectivity) 根据spectral graph theory, 拉普拉斯特征值可以作为图的连通度的代数衡量,越大的特征值表示图的连通程度越好,因此可以使用
λ
k
+
1
\lambda_{k+1}
λk+1来间接反映连通性。
如图所示,更强的数据增强,即裁剪区域越小,越有可能将原本不同类别的数据增强为事实上同一类别的数据,因而增加了样本间联系,连通性越容易增强。
同时,data inflation因为新生成数据样本,产生更多的相同类别数据,因此也会增加连通性。这可以用图的采样率(即从整个样本中选取部分样本的比例)解释,越小的采样率,采样子图的连通性越小,又因为非数据膨胀的样本图可以看为是带数据膨胀样本图的子集,因而非数据膨胀样本图的连通性小,说明data inflation可以增加连通性。相关理论源可由引理4.2解释
引理4.2 假设
G
G
G是由
n
n
n个顶点,spectral gap
λ
=
m
i
n
{
λ
2
,
2
−
λ
N
}
\lambda=\mathrm{min} \{\lambda_2, 2-\lambda_N\}
λ=min{λ2,2−λN},结点最小度数为
d
m
i
n
d_{min}
dmin组成的图,
H
H
H是
G
G
G的子图,选择
G
G
G的边的概率为
p
p
p,则有
λ
H
=
λ
−
O
(
l
o
g
n
p
d
m
i
n
+
(
l
o
g
n
)
3
/
2
p
d
m
i
n
(
l
o
g
n
)
3
/
2
)
(
3
)
\lambda_H=\lambda - \mathcal{O}(\sqrt{\frac{ \mathrm{log}n}{pd_{min}}} + \frac{(\mathrm{log}n)^{3/2}}{pd_{min}(\mathrm{log}n)^{3/2}} ) \quad \quad (3)
λH=λ−O(pdminlogn+pdmin(logn)3/2(logn)3/2)(3)
显然,
p
p
p越大,采样率就越大,子图就越大,
O
(
)
\mathcal{O}()
O()就越小,
λ
H
\lambda_H
λH就越大,总的连通程度就越大
如上总结,data inflation和data augmentation在影响预测效果的4个指标上存在互补关系。
data augmentation可以提高
λ
k
+
1
\lambda_{k+1}
λk+1来降低
ε
(
f
∗
,
B
∗
)
\varepsilon(f^*,B^*)
ε(f∗,B∗),但是其同时也会增大
α
\alpha
α,这会加大
ε
(
f
∗
,
B
∗
)
\varepsilon(f^*,B^*)
ε(f∗,B∗),因此该操作具有冲突性。而data inflation会增大
λ
k
+
1
\lambda_{k+1}
λk+1,但不会影响
α
\alpha
α.因此,当data inflation使用较为充分,
λ
k
+
1
\lambda_{k+1}
λk+1有一定提高时,采用较弱的data augmentation将
λ
k
+
1
\lambda_{k+1}
λk+1进一步提高,同时
α
\alpha
α不过分增大。而当数据过少,则需要使用较强的data augmentation来充分增大
λ
k
+
1
\lambda_{k+1}
λk+1。适度的对两种数据处理方式进行强度调整,可以在不增加计算复杂度的情况下获得更好的模型效果。
作者提出的AdaInf
在CIFAR上,真实:生成数据的混合比为10:1,同时采用的data augmentation策略较弱:
相对最小裁剪比
a
a
a:0.08->0.2,ColorJitter(包括增强图片亮度,对比度,饱和度,色调)强度:1->0.5,ColorJitter概率:0.8->0.4
附上定理4.1的证明