Learning CNN on ViT: A Hybrid Model to Explicitly Class-specific Boundaries for Domain Adaptation
Learning CNN on ViT: A Hybrid Model to Explicitly Class-specific Boundaries for Domain Adaptation 这篇文章提出了一种结合卷积神经网络(CNN)和视觉转换器(ViT)的混合模型,称为“显式类别特定边界模型”(Explicitly Class-specific Boundaries, ECB)。该模型旨在充分利用CNN和ViT各自的优势,特别是在领域适应(Domain Adaptation, DA)任务中,通过显式寻找类别特定的决策边界来提高性能。
代码链接:https://github.com/dotrannhattuong/ECB
主要贡献
-
混合模型的设计:该研究引入了一种混合模型,通过在ViT上学习CNN,充分利用ViT的全局特征捕捉能力和CNN的局部特征提取优势。ViT用于识别更一般的类别特定边界,而CNN则通过聚类目标特征来减少错误。
-
新颖的学习策略:该方法使用一种称为“寻找到征服”的策略,通过最大化和最小化分类器输出之间的差异来优化模型。具体来说,ViT负责扩展类别特定边界,而CNN则在这些边界的基础上进行特征聚类。
-
在DA任务中的优势表现:实验结果表明,与传统的DA方法相比,该混合模型在多个DA基准数据集上取得了更优的表现,验证了其有效性。
-
引言
-
在计算机视觉领域,深度学习是用于处理和理解图像和视频数据的一种先进技术。**卷积神经网络(CNN)和视觉转换器(ViT)**是目前深度学习中最常用的两种模型架构。这篇文章的引言部分介绍了这两种架构的不同特点,并解释了为什么有必要将它们结合起来,提出一种新的混合模型。
什么是卷积神经网络(CNN)?
CNN是一种专门处理图像的神经网络。你可以把它想象成一种“图像侦探”,它擅长从图像的细节中找到有用的信息。例如,在一张猫的照片中,CNN可以很容易地识别出猫的耳朵、眼睛和尾巴等具体的部分。
CNN通过使用卷积层来提取这些细节信息。卷积层是一种数学运算,可以帮助模型聚焦于图像中的局部特征,比如边缘、角落和纹理。这种局部特征的提取能力使得CNN在处理小规模数据集时表现非常好,因为它能够有效地利用数据中的每一部分来学习。
然而,CNN也有一个局限性,那就是它只擅长捕捉图像的局部信息,而不擅长理解整个图像的全局结构或整体背景。这意味着,当处理更复杂的视觉任务,尤其是需要理解整个图像中的不同部分如何相互关联时,CNN可能表现不如预期。
什么是视觉转换器(ViT)?
为了克服CNN的局限性,研究人员引入了视觉转换器(ViT)。ViT是一种使用**自注意力机制(Self-Attention Mechanism)**来处理图像的模型。简单来说,自注意力机制是一种让模型能够“关注”图像中不同部分并理解它们之间关系的方法。
与CNN不同,ViT并不是通过局部特征来理解图像的,而是通过全局的方式来看待图像。例如,在看猫的照片时,ViT不仅会注意到猫的耳朵和眼睛,还会看到整个背景,比如沙发或窗帘。这样,它就能更好地理解图像的全局上下文和不同部分之间的关系。
ViT的这种全局特征提取能力使得它在处理大规模数据集时表现非常出色,特别是在那些需要对整个图像有整体理解的任务中,ViT往往能取得更好的效果。
为什么需要结合CNN和ViT?
既然CNN和ViT各自都有优点和缺点,那么为什么不将它们结合起来,利用它们各自的优势呢?这就是本文的核心思想。
文章指出,尽管ViT在处理大数据集时具有优势,但它在小数据集上容易过拟合,意思是它可能会因为数据太少而记住每一个细节,反而无法很好地泛化到新数据。另一方面,CNN在小数据集上表现良好,但由于它缺乏对全局信息的理解能力,在需要更多上下文的任务中可能表现不佳。
因此,作者提出了一种新的方法:将CNN和ViT结合起来,构建一个混合模型。这种模型可以同时利用CNN的局部特征提取能力和ViT的全局特征捕捉能力,从而在不同类型的数据和任务中都能表现出色。
方法
- Methodology部分详细描述了作者提出的混合模型的设计和训练过程。这个模型被称为“显式类别特定边界模型”(Explicitly Class-specific Boundaries, ECB),它结合了卷积神经网络(CNN)和视觉转换器(ViT),以利用这两种架构各自的优势。
-
1、数据集设置
在领域适应(Domain Adaptation, DA)任务中,通常我们有两个数据集:
- 源域(Source Domain):这是一组有标签的数据,也就是说,每个样本都已经知道其所属类别。
- 目标域(Target Domain):这是模型要适应的目标数据集,它包括一些没有标签的数据(在无监督领域适应中)和一些带标签的数据(在半监督领域适应中)
-
符号说明:
2、模型架构
- 编码器(Encoder):用于提取输入数据的特征。ViT编码器擅长提取全局特征,而CNN编码器擅长提取局部特征。
- 分类器(Classifier):用于将提取到的特征映射到各个类别上。通过优化分类器的输出,模型可以学习更好地区分不同类别的样本。
-
每个编码器和分类器的作用如下:
模型的训练过程分为三个阶段:监督训练、寻找到征服(Finding to Conquering)策略、以及协同训练(Co-training)。
①监督训练(Supervised Training)
在这个阶段,我们对模型进行初始训练,使用的是源域和带标签的目标域数据。
对于ViT分支,使用标准的交叉熵损失函数(Cross-Entropy Loss)来最小化带标签数据的经验损失:
-
类似地,CNN分支的训练也是通过最小化交叉熵损失来实现的:
-
这个阶段的主要目的是让两个分支分别学习如何从数据中提取有用的特征,并对样本进行分类。
-
②寻找到征服策略(Finding to Conquering)
这个策略分为两个子阶段:寻找阶段和征服阶段。
寻找阶段(Finding Stage)
征服阶段(Conquering Stage)
③ 协同训练(Co-training)
在“寻找到征服”策略之后,ViT和CNN分支之间可能仍然存在显著的知识差异。为了进一步减少这种差异,作者引入了协同训练策略。
协同训练策略
协同训练通过让两个分支相互学习对方的预测来提升模型的性能。这一过程包括以下两个目标:
- 目标一:减少两个分支之间的差距,使它们能够相互增强,提高伪标签的质量。
- 目标二:利用ViT分支捕捉数据中复杂模式和关系的潜力,专门提升CNN分支的性能。
-
测试阶段(Testing Phase)
-
作者在本章提出的ECB模型通过结合ViT和CNN的优势,设计了一种新颖的混合架构。这种架构不仅通过显式类别特定边界的策略来优化模型的性能,还利用了协同训练策略来减少模型内部的知识差异。通过这种方法,模型能够在领域适应任务中表现得更加优越。
实验结果
-
实验结果部分展示了本文提出的ECB模型在标准领域适应(Domain Adaptation, DA)基准数据集上的性能表现。作者对比了多种最先进的(State-of-the-Art, SOTA)方法,验证了ECB模型的有效性。
实验设置(Experiment Setup)
在开始具体实验结果之前,先了解实验设置是很重要的。
-
数据集:作者在两个标准DA基准数据集上进行了广泛评估:
- Office-Home:包含65个类别的四个域:Real(真实世界图片,R)、Clipart(剪贴画,C)、Art(艺术品,A)和Product(产品图片,P)。
- DomainNet:包含126个类别的四个域:Real(真实世界图片,rel)、Clipart(剪贴画,clp)、Painting(绘画,pnt)和Sketch(素描,skt)。
-
实验场景:
- 无监督领域适应(Unsupervised Domain Adaptation, UDA):仅使用有标签的源域数据和无标签的目标域数据。
- 半监督领域适应(Semi-Supervised Domain Adaptation, SSDA):使用有标签的源域数据、少量有标签的目标域数据和无标签的目标域数据。
-
模型架构和优化:
- ViT编码器:使用ViT/B-16。
- CNN编码器:使用ResNet架构(如ResNet-50、ResNet-34)。
- 优化器:使用随机梯度下降法(Stochastic Gradient Descent, SGD),动量为0.9,权重衰减为0.0005。
- 学习率:ViT的初始学习率为1e-4,CNN的初始学习率为1e-3。
- 分析:ECB模型在所有任务上都取得了最佳性能。例如,在任务C→A、C→R和P→A中,ECB模型的准确率分别比第二好的方法提升了7.7%、8.1%和7.2%。平均而言,ECB模型的准确率比次佳方法高出5.4%。
- 分析:在1-shot和3-shot设置下,ECB模型的表现均优于其他方法。例如,在任务skt→pnt中,ECB模型在3-shot设置下比第二好的方法(G-ABC)提高了9.3%。即使在更为严格的1-shot设置下,ECB方法在任务rel→clp中仍然表现出了3.1%的提高。总体而言,ECB方法在1-shot设置下提升了6.6%,在3-shot设置下提升了7.1%。
-
结果分析
Office-Home 数据集上的结果(UDA设置)
在Office-Home数据集上的无监督领域适应实验中,ECB模型表现显著优于其他方法。下表展示了不同方法在Office-Home数据集上的准确率表现(以百分比表示):
方法 A→C A→P A→R C→A C→P C→R P→A P→C P→R R→A R→C R→P 平均准确率 DANN 45.6 59.3 70.1 47.0 58.5 60.9 46.1 43.7 68.5 63.2 51.8 76.8 57.6 MCD 48.9 68.3 74.6 61.3 67.6 68.8 57.0 47.1 75.1 69.1 52.2 79.6 64.1 MCC 55.1 75.2 79.5 63.3 73.2 75.8 66.1 52.1 76.9 73.8 58.4 83.6 69.4 ECB (CNN) 68.5 85.4 88.3 79.2 86.8 89.0 79.3 66.4 88.5 81.0 71.1 90.4 81.2 DomainNet 数据集上的结果(SSDA设置)
在DomainNet数据集上的半监督领域适应实验中,ECB模型也表现出了优越的性能。下表展示了不同方法在DomainNet数据集上的准确率表现(以百分比表示),包括1-shot和3-shot设置:
方法 rel→clp rel→pnt pnt→clp clp→skt skt→pnt rel→skt pnt→rel 平均准确率 ENT 65.2/71.0 65.9/69.2 65.4/71.1 54.6/60.0 59.7/62.1 52.1/61.1 75.0/78.6 62.6/67.6 MME 70.0/72.2 67.7/69.7 69.0/71.7 56.3/61.8 64.8/66.8 61.0/61.9 76.1/78.5 66.4/68.9 ECB (CNN) 83.8/87.4 85.4/85.6 86.4/87.3 79.7/80.6 83.4/85.6 79.5/81.7 88.7/90.3 83.8/85.5 消融研究(Ablation Study)
为了进一步验证每个组成部分的有效性,作者还进行了消融研究,重点评估了在不同组合和设置下模型的表现。
实验结果清晰地表明,ECB模型通过结合ViT和CNN的优势,并采用显式类别特定边界和协同训练策略,显著提高了领域适应任务的分类精度。
-
协同训练的重要性:实验结果显示,单向学习(如vit→cnn或cnn→vit)的表现低于双向协同训练。例如,在vit→cnn场景中,ViT分支生成伪标签用于训练CNN分支,而ViT分支自身没有从CNN分支中学习到伪标签,导致CNN分支的表现优于ViT分支3.3%。相比之下,当采用双向协同训练时,模型的整体性能达到最佳,两个分支的准确率分别为85.7%和85.5%。
-
架构分析:在不同架构组合(如“CNN + CNN”和“ViT + ViT”)下的性能显示,“ViT + ViT”架构比“CNN + CNN”提高了7.0%的准确率,但仍然缺乏局部信息的学习。而结合了CNN和ViT的混合架构通过协同训练达到了最高的准确率,为85.5%。
-
总结
- 本文提出的混合模型通过结合ViT和CNN的优势,在领域适应任务中显著提高了分类精度。研究表明,这种混合方法能够有效减少数据偏差,改善伪标签的准确性,且在不同的DA基准数据集上都表现出色。
-
创新性:该研究的创新之处在于将ViT和CNN结合起来,利用两者在不同任务上的优势,通过显式类别特定边界的学习策略,提高了模型的适应性和准确性。
-
实践意义:这种混合模型为解决领域适应任务中的数据偏差问题提供了一种有效的途径,具有较高的应用潜力,尤其是在需要处理多源异构数据的复杂视觉任务中。