[ICLR 2024] Do Generated Data Always Help Csontrastive Learning? 阅读笔记

摘要:

对比学习常用于无监督视觉表征学习中,需要依赖于大量手工标注的增强数据,而各种生成模型可以作为数据膨胀data inflation的手段

然而它们其实可能对对比学习有害,作者分析了原因,揭示了数据膨胀和数据增强的反比关系

作者对性能下降现象做了理论解释,主要推理了数据膨胀前提下的泛化边界,并首次启发,提出了Adaptive Inflation (AdaInf),是一种以数据为中心的数据膨胀策略

作者采用SimCLR在CIFAR上进行测试,使用AdaInf得到了效果提升

Chapter1 Intro

文中指出,数据膨胀(data inflation)就是简单使用生成模型产出的数据文中指出,数据膨胀就是简单使用生成模型产出的数据

而数据增强data augmentation是指对数据进行一系列操作(如裁剪,旋转),增加正负对样本以促进对比学习性能的手段

作者针对data inflation和data augmentation两方面进行性能下降的原因研究,发现data inflation中生成图像的质量作用有限,调整real和generated数据的比例可以改善性能。但在data augmentation方面,作者意外发现在采用data inflation的情况下,较弱的data augmentation竟可以提高性能

为了解释这一现象,作者剖析了data inflation和data augmentation的互补作用,并基于相关见解提出了Adaptive Inflation (AdaInf)策略,可以适应性调整数据增强强度和数据膨胀的混合比例,在不带来额外计算的前提下提高下游任务的性能。

Chapter3 分析性能下降的影响因素

关于data inflation

define: D d \mathcal{D}_d Dd: real data D g \mathcal{D}_g Dg: generated data

distribution of D d \mathcal{D}_d Dd and D g \mathcal{D}_g Dg: P d , P g P_d, P_g Pd,Pg -----------> total overall distribution: P t = β P d + ( 1 − β ) P g P_t = \beta P_d + (1-\beta) P_g Pt=βPd+(1β)Pg

where β = ∣ D d ∣ ∣ D d ∣ + ∣ D g ∣ \beta = \frac{|D_d|}{|D_d| + |D_g|} β=Dd+DgDd

若效果越好,则总的数据和真实数据的差异应当越小越小,该差异可在分布空间中体现,而分布的差异可以用全变分距离(total variation distance)

全变分距离介绍:https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures

T V ( P , Q ) = 1 2 ∫ ∣ d P ( x ) − d Q ( x ) ∣ TV(P,Q)=\frac{1}{2} \int |\mathrm{d}P(x)- \mathrm{d}Q(x)| TV(P,Q)=21dP(x)dQ(x) ,一些情况下系数可以省略

因此,可以找到最优化目标: minimize: D t v ( P t , P d ) \mathrm{D}_{\mathrm{tv}}(P_t, P_d) Dtv(Pt,Pd)

该式子中的总分布t包含了真实数据和生成数据,可以转化到仅包含生成数据的分布 P g P_g Pg

定理3.1
D t v ( P t , P d ) = ( 1 − β ) D t v ( P g , P d ) ( 1 ) \mathrm{D}_{\mathrm{tv}}(P_t, P_d)=(1-\beta)\mathrm{D}_{\mathrm{tv}}(P_g, P_d) \quad \quad (1) Dtv(Pt,Pd)=(1β)Dtv(Pg,Pd)(1)

证明:

由此可见,总体数据质量的好坏可以由真实数据比例 β \beta β和生成数据分布 P g P_g Pg决定。

作者使用了不同质量的扩散模型来改进 P g P_g Pg,发现收效甚微

作者调整不同的 β \beta β值,发现真实-生成数据复制比例达到10:1时,可以取得最佳性能
请注意,上面说的10:1指的是复制对应的数据的比例,即真实数据复制10次,生成数据仅1次,之后混合,而不是指真实数据是生成数据的10倍

关于data augmentation

作者进一步研究了数据增强的效果,主要采用的data augmentation策略为随机缩放裁剪,可以改变相对最小裁剪比例 α \alpha α来控制数据增强的强弱大小, a a a越小,强度越大

对比实验采用SimLCR网络运行,通过控制生成数据量来控制data inflation强度,控制 a a a来控制data augmentation强度,作者发现了加入适量的data inflation,同时伴随较弱的data augmentation可以有效提高性能

Chapter4 对数据膨胀,数据增强如何影响性能做理论分析

作者采用了图结构对数据增强进行阐述,把数据样本作为结点,把增强手段作为边

define: 膨胀数据集: X ‾ \overline{\mathcal{X}} X(包含了真实数据和生成数据); 其经过data augmentation的数据集: X \mathcal{X} X
建立关于 X \mathcal{X} X的增强数据图,用邻接矩阵 A ∈ R n × n A \in \mathbb{R}^{n\times n} ARn×n( n n n应该是所有增强样本的数量)表示, A A A表示在数据增强条件下的正样本联合概率
其中,对于由 x ‾ \overline{x} x增强得到的样本 x , x ′ x,x^{\prime} x,x, 有 A x , x ′ = E x ‾ ∼ P X ‾ A ( x ∣ x ‾ ) ⋅ A ( x ′ ∣ x ‾ ) A_{x,x\prime}=\mathbb{E}_{\overline{x} \sim \mathcal{P}_{\overline{\mathcal{X}}}}\mathcal{A}(x|\overline{x})\cdot \mathcal{A}(x^{\prime}|\overline{x}) Ax,x=ExPXA(xx)A(xx),其中 A ( x ∣ x ‾ ) \mathcal{A}(x|\overline{x}) A(xx)表示由 x ‾ \overline{x} x增强的 x x x是正样本的条件概率, E \mathbb{E} E表示求选出数据的期望
引入图拉普拉斯: L = I − D − 1 2 A D − 1 2 \mathcal{L}=I-D^{-\frac{1}{2}}AD^{-\frac{1}{2}} L=ID21AD21,其中 D D D是个对角矩阵,可以定义为: D x x = ∑ x ′ A x x ′ D_{xx}=\sum_{x\prime}A_{xx\prime} Dxx=xAxx, I I I是单位和 D , A D,A D,A一样大小的单位矩阵

若认为生成数据和真实数据是相同的无差异,那么真实数据集 X ‾ r a w \overline{\mathcal{X}}_{\mathrm{raw}} Xraw就可以认为是 X ‾ \overline{\mathcal{X}} X的一个子集,同样仅针对真实数据的增强样本图也可认为是 A A A的一个子图
define:拉普拉斯矩阵 L \mathcal{L} L N N N个特征值: 0 = λ 1 ⩽ λ 2 ⩽ ⋯ ⩽ λ N ⩽ 2 0=\lambda_{1} \leqslant \lambda_{2} \leqslant \cdots \leqslant \lambda_{N} \leqslant 2 0=λ1λ2λN2

作者接下来以线性探测(Linear Probing)任务为例进行原理说明
设有一线性分类器 g f , B g_{f,B} gf,B,如下图示:

其中线性分类器权重矩阵 B ∈ R k × r B \in \mathbb{R}^{k \times r} BRk×r, k k k 是特征通道数, r r r 是类别数
膨胀数据集样本 x ‾ \overline{x} x的类别通过投票分类器决定,即 g ‾ f , B ( x ‾ ) : = a r g m a x i ∈ [ r ] P r x ∼ A ( ⋅ ∣ x ‾ ) ( g f , B ( x ) = i ) \overline{g}_{f,B}(\overline{x}):=\mathrm{argmax}_{i \in [r]} \mathrm{Pr}_{x \sim \mathcal{A}(\cdot|\overline{x})}(g_{f,B}(x)=i) gf,B(x):=argmaxi[r]PrxA(x)(gf,B(x)=i),意思就是对所有 x ‾ \overline{x} x的增强数据做预测,取预测为某一类别次数最多的那个类别作为 x ‾ \overline{x} x的预测结果
define: 分类错误率 ε ( f , B ) \varepsilon(f,B) ε(f,B),值越小,说明准确率越高
定理4.1 至少有 1 − δ 1-\delta 1δ的概率,对于最优的编码器 f ∗ f^* f和学习的分类器权重 B ∗ B^* B,线性探测误差存在以下上界:

ε ( f ∗ , B ∗ ) ⩽ 8 α λ k + 1 + 16 α + 2 ( 1 − β ) D T V ( P d , P g ) ( 2 ) \varepsilon(f^*,B^*) \leqslant \frac{8\alpha}{\lambda_{k+1}} + 16\alpha + 2(1-\beta)\mathrm{D}_{\mathrm{TV}}(P_d,P_g) \quad \quad (2) ε(f,B)λk+18α+16α+2(1β)DTV(Pd,Pg)(2)
其中 α = E x ‾ ∼ P d , x ∼ A ( ⋅ ∣ x ‾ ) 1 [ y ( x ) ≠ y ( x ‾ ) ] \alpha=\mathbb{E}_{\overline{x} \sim \mathcal{P}_{d}, x \sim \mathcal{A}(\cdot|\overline{x})}\mathbb{1}[y(x)\ne y(\overline{x})] α=ExPd,xA(x)1[y(x)=y(x)],即为 x ‾ \overline{x} x增强为 x x x的过程中的标签错误率; λ k + 1 \lambda_{k+1} λk+1 A A A的拉普拉斯矩阵 L \mathcal{L} L的第 k + 1 k+1 k+1小的特征值
定理4.1的公式(2)是理解数据膨胀和数据增强影响模型识别准确率的核心。其值可由 β , D T V ( P g , P d ) \beta,\mathrm{D}_{\mathrm{TV}}(P_g,P_d) β,DTV(Pg,Pd), α \alpha α λ k + 1 \lambda_{k+1} λk+1决定。通过第3章的分析已知混合比例 β \beta β是通过data inflation控制的, D T V ( P g , P d ) \mathrm{D}_{\mathrm{TV}}(P_g,P_d) DTV(Pg,Pd)取决于生成模型本身的质量,二者都可归类为data inflation策略

而决定公式(2)大小的另外两个因素,接下来做详细阐述。
标签错误(Labeling error) 由于random resize crop的作用,增强数据时会对原数据进行部分裁剪,且裁剪后的样本与原样本的标签类别一致,但有可能造成实际真值的变化。比如原本的图片是茶壶,裁剪后的实际内容变成了茶罐,然而标签值依旧是茶壶。而且增强程度越大(即裁剪图越小),越容易得到更局部的图像,增大标签错误率 α \alpha α

图的连通程度(graph connectivity) 根据spectral graph theory, 拉普拉斯特征值可以作为图的连通度的代数衡量,越大的特征值表示图的连通程度越好,因此可以使用 λ k + 1 \lambda_{k+1} λk+1来间接反映连通性。

如图所示,更强的数据增强,即裁剪区域越小,越有可能将原本不同类别的数据增强为事实上同一类别的数据,因而增加了样本间联系,连通性越容易增强。
同时,data inflation因为新生成数据样本,产生更多的相同类别数据,因此也会增加连通性。这可以用图的采样率(即从整个样本中选取部分样本的比例)解释,越小的采样率,采样子图的连通性越小,又因为非数据膨胀的样本图可以看为是带数据膨胀样本图的子集,因而非数据膨胀样本图的连通性小,说明data inflation可以增加连通性。相关理论源可由引理4.2解释
引理4.2 假设 G G G是由 n n n个顶点,spectral gap λ = m i n { λ 2 , 2 − λ N } \lambda=\mathrm{min} \{\lambda_2, 2-\lambda_N\} λ=min{λ2,2λN},结点最小度数为 d m i n d_{min} dmin组成的图, H H H G G G的子图,选择 G G G的边的概率为 p p p,则有
λ H = λ − O ( l o g n p d m i n + ( l o g n ) 3 / 2 p d m i n ( l o g n ) 3 / 2 ) ( 3 ) \lambda_H=\lambda - \mathcal{O}(\sqrt{\frac{ \mathrm{log}n}{pd_{min}}} + \frac{(\mathrm{log}n)^{3/2}}{pd_{min}(\mathrm{log}n)^{3/2}} ) \quad \quad (3) λH=λO(pdminlogn +pdmin(logn)3/2(logn)3/2)(3)
显然, p p p越大,采样率就越大,子图就越大, O ( ) \mathcal{O}() O()就越小, λ H \lambda_H λH就越大,总的连通程度就越大

如上总结,data inflation和data augmentation在影响预测效果的4个指标上存在互补关系。
data augmentation可以提高 λ k + 1 \lambda_{k+1} λk+1来降低 ε ( f ∗ , B ∗ ) \varepsilon(f^*,B^*) ε(f,B),但是其同时也会增大 α \alpha α,这会加大 ε ( f ∗ , B ∗ ) \varepsilon(f^*,B^*) ε(f,B),因此该操作具有冲突性。而data inflation会增大 λ k + 1 \lambda_{k+1} λk+1,但不会影响 α \alpha α.因此,当data inflation使用较为充分, λ k + 1 \lambda_{k+1} λk+1有一定提高时,采用较弱的data augmentation将 λ k + 1 \lambda_{k+1} λk+1进一步提高,同时 α \alpha α不过分增大。而当数据过少,则需要使用较强的data augmentation来充分增大 λ k + 1 \lambda_{k+1} λk+1。适度的对两种数据处理方式进行强度调整,可以在不增加计算复杂度的情况下获得更好的模型效果。

作者提出的AdaInf

在CIFAR上,真实:生成数据的混合比为10:1,同时采用的data augmentation策略较弱:
相对最小裁剪比 a a a:0.08->0.2,ColorJitter(包括增强图片亮度,对比度,饱和度,色调)强度:1->0.5,ColorJitter概率:0.8->0.4

附上定理4.1的证明

在这里插入图片描述
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值