[Transformer]Efficient Training of Visual Transformers with Small Datasets


NIPS2021

Abstract

ViT正逐渐成为替代CNN的新范式,与CNN不同之处在于ViT可以捕获元素之间的全局关系,并且具备更强大的表征能力。但是ViT缺乏卷积的归纳偏置,为了学习一些局部特征只能借助大量的训练样本,因此对数据量要求更高。本文分析了不同ViT框架在小规模数据集上的鲁棒性,并且发现尽管在ImageNet上预训练具有相近的准确性,但是在小规模数据集上却有较大差异。
此外本文日出一种辅助的自监督方式可以从图像中提取额外的信息,带来的计算开销却可以忽略不计,这就可以使得ViT有效学习图像内的空间关系,并且在小规模数据集上更具鲁棒性。本文将这一方法与监督训练结合进行,并且这种方法不依赖于特定的架构。在不同ViT框架上的实验结果表明本文的方法可以显著提升ViT最终的精度。

Section I Introduction

ViT逐渐在计算机视觉领域成为CNN的替代方案,已成功用于图像分类分割目标检测图像生成等任务,VIT的优点之一是为处理视觉信息和文本信息提供了一个统一的处理范式。这一方面开创性的工作是ViT。


ViT的显著优势是可以使用注意力层来建模token之间的全局关系,这是与CNN主要的区别,因为卷积受到卷积核的限制感受野是局部的。但是ViT的代价就是缺乏CNN的归纳偏置,没有平移不变性和层次化的信息结构。因此ViT对数据的需求更大,需要从大量数据中学习各种模式,因此通常需要在JFT-300M、ImageNet-31K等大型数据集进行预训练。



为了缓解这一问题,有不同改进方法,如将卷积层与注意力层混合,这样可以为VIt提供局部的归纳偏置,并且在大型数据集上取得了一定性能提升,但是在中小数据集上的表现还不清楚。并且实际任务中并不总是能具有和ImageNet相当的数据集。




本文比较了不同的第二代ViT框架在中小数据集上从头训练或者微调的性能。可以发现虽然在ImageNet上性能相近,但是迁移到小规模数据集上性能就不尽相同。并且本文还比较了具有相同容量的ResNet和ViT发现,在大多数情况下小规模数据集上ViT的性能是可以与ResNet相媲美的。




本文提出一种辅助的自监督任务,并且提出响应的损失函数来约束在中小规模数据集上的训练。具体来说就是吴建福学习token之间的空间关系,从嵌入后的网路随机抽取pairs计算二者之间的集合距离。网路需要在进行嵌入的同时编码局部和上下文信息,如果不嵌入局部信息,那么不同输入图像的patch embedding就五大区分;而如果不编码上下文信息那么区分能力就不强,会比较模棱两可。

本文的工作受到ELECTRA的启发,ELECTRA也是需要在NLP任务中能够更高效的提取样本信息。本文莱斯的使用多个token embedding来表征一张图像,然后计算所有可能的pairs相对距离确定物体的位置.这样单张图像在前向传递过程中可以比较多组pairs并且计算他们的平均损失。因此本文与之前需要单独传递input patch的任务不同,本文可以在一个较大的网格中不用计算所有的组合就建模所有依赖关系。


本文提出的辅助之间的损失函数就是dense relative localization loss,不需要额外的注释,将结合交叉熵损失函数一起使用来提升ViT的性能。并且在不同实验中显示总是比基线网络有所提升,最多能提升数十个点。



本文工作总结如下:




(1)比较了不同ViT框架发现在小规模数据集上训练较少epoch他们的性能有很大差距;




(2)本文提出一种相对定位的辅助损失函数用于提升ViT性能;




(3)本文通过大量实验证明可以提升ViT的泛化能力,加速训练,并且是独立于网络结构即插即用的。

Section II Related Work

Visual Transformers


虽然之前也有工作尝试在CNN中使用注意力,但是第一个纯Transformeruan框架还是iGPT和ViT。ViT是一种有监督的训练方式,使用cls token和cls head来执行分类任务,但是计算成本很大并且在大型数据集上训练效果略逊于ResNet。
VidelBER类似iGPT,但是不是处理像素而是将每一帧图像作为一个特征向量整体表示;DeiT则是借助网络蒸馏来训练Vit。
但是如前所述ViT缺乏局部归纳偏置,使得模型不能脱离对大数据集的依赖。


因此第二代ViT聚焦于CNN和Transformer的混合框架,token被reshape会二维这样就可以使用卷积进行嵌入,增加局部信息进去。



因此本文实验中先试用三种较为先进的第二代ViT:T2T,Swin,CvT,每种模型其参数量与ResNet-50接近。
并且与原始VIt相近本文也使用绝对位置编码来提供token之间的顺序信息,因为注意力层和FFN层都是排列不变的,在[35]中使用了相对位置嵌入。本文的相对定位损失函数则作为一个pretext task来提取额外的信息,不需要人工监督。




Self-supervised learning





自监督首先成功用于NLP任务中,作为无需昂贵人工标注的一种替代方法。典型的一种NLP Pretet task是掩盖输入序列的work然后训练网络预测哪个单词被掩膜掉了。
本文受此启发,提出一种图像的pretext task可以通过密集采样ViT 嵌入后的信息来计算标签,并且本文不需要预训练模型也不需要替换输入token,而是会计算token之间的集合距离。
因为NLP中语句长度有限,替换是可行的,但是图像patch是高度连续的,因此很难使用replace这种方法。
在计算机视觉任务中常见的pretext task是从同一图像找那个提取两种不同视角,作为一对正样本来提取共享的语义特征。本文的一大突破是没有提出一种全自监督的方法而是联合使用标准监督和自监督来规范ViT的训练。密集定位损失函数计算的不是pair之间的损失而是当前批次中同一图像的不同视角来计算的,并且这种方法可以与数据增强一起使用。因此本文的pretext task是计算统一图像不同pairs的相对位置。
之前基于定位的自监督方法主要基于预测图像旋转或者patch的相对位置,而本文需要在同一图像中提取多个patch,那么损失也是计算随机pairs之间的相对距离。
在这里插入图片描述

Section III Preliminaries

Fig 1展示了ViT的处理流程,(a)是常规的第二代ViT分类结构,(b)是本文的localization MLP将任意一组token paires级联后作为输入。


原始VIT会将一系列重叠或者互不重叠的图像切片作为输入,每个patch进行投影获得一组KXK的token,然后送入Transformer建模pair之间的关系。
而混合框架一般会将token reshape会二维空间这样可以进行卷积操作,一般使用步长大于1的步长卷积或者池化操作来降低初始KXK大小的分辨率,从而模拟CNN的层次结构。
最终KXK的token输出还有一个cls token用于实现分类任务,cls token会在整个grid范围内手机上下文信息,最后使用一个MLP将所有的head作为输入,输出预测结果。损失函数为交叉熵损失函数。

当本文将relative localization loss插入ViT中时除了使用localization MLP之外没有其他改动,依旧使用交叉熵损失函数,唯一的改动就是对T2T和CvT最终的grid进行下采样,使得他们与Swin的大小一样。Swin中最终gird分辨率为7x7,其余两者为14x14.

Section IV Dense relative localization task

本文的正则化任务是鼓励ViT学习空间信息而不借助额外的人工标注。本文通过对每张图像采样多组embedding pairs然后让网络预测他们之间的相对距离。
具体来说对于一张图像x切分成kxk大小的网格,每一个grid随机采样一组嵌入后的结果(Eij,Eph)然后计算相对偏移量:


在这里插入图片描述

被选到的嵌入vector会级联后送入一个MLP中,MLP包含两层隐藏层,输出层包含2个神经元,用于预测两个位置之间的相对距离,dense relative localization loss定义为:


在这里插入图片描述

对于每一个图像,会随机采样m对,并且计算所有损失的均值。Ldeloc会和交叉熵损失函数一起作为整体loss,并且比重由λ系数调控。
在T2T和CvT中λ=0.1,Swin中λ=0.5
在这里插入图片描述

.
接下来需要讨论额是ViT中使用相对位置嵌入是否足以解决空间定位问题?本文5.2-5.3的实验结果表明如果只使用Ldrloc而不使用位置嵌入精度提升没有那么大,本文之前提到T2T,CvT最终池化后会得到一个7x7的网格Gx,事实是因为分辨率若为14x14会使得Ldrloc收敛的十分缓慢,本文认为这是由于空间定位更加困难,会减慢收敛速度甚至产生噪声梯度。因此本文使用7x7的分辨率。

Section V Experiments

本文所有试验都是图像分类任务,使用11个不同的数据集,分别是ImageNet,CIFAR-10/100,Oxford flowers,SVHN以及剩下六个用于域迁移的数据集,这样就使得从ImageNet迁移到下游数据集十分具有挑战性。
在这里插入图片描述

Table 1展示了数据集的情况。
本文使用T2T,Swin的官方实现,CvT用的是非官方的一个开源实现,对比的是参量相近的ResNet模型,对比结果参见Table 3,除了使用本地MLP之外没有进行其他改进,使用的同样的数据增强等。

Part 1 Ablation Study

Table 2展示了不同搭配下对精度的影响,本文最终确定m=64即采样64组,localization loss权重系数λ=0.1,Swin网络设置为0.5.
其他参数在本文都保持不免主要就是说明与数据集、任务、训练方案无关,强调本文这种损失的易用性。
在这里插入图片描述

Part 2 Training from scratch

本节使用ViT和本文提出的正则化损失来在中小型数据集上训练不同epoch,事实上对于这种应用场景通常会先在ImageNet上进行预训练,但是这种方法并不是通用的,比如对于3D点云数据或者使用指定网络框架时通常需要从头训练。
本文首先测试了在IN-10上的性能,可以看到有了Ldeloc的辅助精度有所提升,尤其在训练较少epoch就能得到提升。正如预期的那样本文的损失函数可以作为一个正则化项,在训练较少epoch时提升尤为明显,而这对VIT的训练尤为重要,因为一般训练ViT会比Resnet更久。
本文在其他数据集上从头开始训练100个epoch,可以看到ViT的性能很大程度上取决于数据集的规模(这是预料之内的)同时也取决于具体的网络结构,可以看到在ImasgeNet上训练的网路top-1精度比较接近,但是从Table 4的结果却可以看出不停网络框架之间存在很大差异,CvT一般比另外两个ViT更加鲁棒性,但是在大规模数据集上这种差别是看不出来的。

在这里插入图片描述

本文在Table 4还对比了使用Ldrloss训练后的准确性,确实使用Ldrloss会对所有网络性能有所提升,说明这种自监督的辅助项可以向ViT提供重要的信息,特别的是使用这种损失可以有效提升从头训练ViT的准确性。
对ResNet的实验表明本文的Ldrloc也有提升,但是由于Reset本身就已嵌入了局部信息,因此在使用定位辅助损失帮助不大。
在这里插入图片描述

Part 3 Fine-tuning

接下来分析典型的需要微调的场景,现在ImageNet上预训练然后在目标域上进行网络调优,验证Ldrloc的作用,结果参见Table 5.


由于CvT没有提供ImageNet预训练的结果所以没有参与对比,可以看到预训练后T2T和Swin之间的性能差异没有那么明显,,并且加入本文的正则化损失后精度会进一步提升。

Section VI Conclusion

本文首先对不同ViT在不同规模数据集上的表现进行分析,发现在中小数据集上性能差异很大,CvT在数据较少时更有效,此外本文还提出一种自监督的损失函数来辅助规范模型的训练,通过密集采样不同的嵌入pair并计算之间的相对距离来引导ViT学习空间信息。

本文在11个数据集进行了广泛实验,发现本文提出的Ldrloss可以显著提升极限精度,有时甚至能提升45个点,这表明Ldrloc是一个有效且易于迁移的性能提升工具,也为研究其他ViT自监督任务、多任务学习提供了解决方法,无需借助大规模数据集。

Limitation

未来将深入分析为什么细粒度嵌入网格对不适合使用辅助任务。此外,本文通过一些ViT-b实验,证明了Ldrloc在大尺寸ViT使用辅助损失的有效性,在实验分析中,主要关注与ResNet-50大小大致相同的VTs。
事实上,本文的目的是研究中小数据集上的ViT行为,因此,在数据匮乏的训练场景中,高容量模型很可能不是最佳选择

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值