论文题目:https://arxiv.org/pdf/2203.01522.pdf
动机
尽管DNN近期在表示学习中取得了巨大成果,但这些方法严重依赖于大规模数据,大规模数据的收集在现实生活中并不是一件容易的事。因此,如何在数据稀缺的条件下学习鲁棒的深度表示特征收到了广泛关注。下图展示了不同样本之间的关系,相似的种类倾向于共享相似的部分(例如,公鸡、秃鹰共享身体形状和爪形)。 因此,将共享知识从 head/seen classes 转移到 tail/unseen classes 可以促进长尾/零样本学习。 此外,探索属于同一类的图像之间的不变特征也有助于使用少量样本学习鲁棒表示。
创新点
1)从DNN内部结构的角度来探索样本关系
2)提出了BatchFormer,用于探索每个mini-batch中的样本关系
3)大量的实验证明了BatchFormer的有效性,包括长尾识别、零样本学习、域泛化和自监督表示学习
方法论
Overview
不同样本之间的关系是多种多样且复杂的,作者从学习的角度考虑样本关系,目标是使DNN本身能够在端到端的深度表示学习期间从每个小bs中学习样本关系。提出的深度表示学习框架如下图所示。具体来说,首先使用骨干网络来学习单个数据样本的表示,即每个小bs中的不同样本之间没有交互。之后,利用Transformer中的交叉注意力机制对不同样本之间的关系进行建模,称为BatchFormer模块。然后将BatchFormer的输出作为最终分类器的输入。作者为了弥补训练和测试之间的差距,还在BatchFormer模块之前使用了一个辅助分类器,通过在最终分类器和辅助分类器之间共享权重,能够将从样本关系中学到的知识转移到主干和辅助分类器。因此,在测试时可以去掉BatchFormer,直接使用辅助分类器进行分类。
BatchFormer
BatchFormer模块堆叠了多个transformer编码层来学习不同样本之间的关系。
Transformer Encoder
transformer encoder包含了多头自注意力(MSA)及MLP层,后面再接了个LN层。 表示一系列输入特征,N是系列特征的长度,C是特征的维度。transformer编码器的公式如下:
其中,l是transformer编码器层数的索引值。MSA已被广泛用于建模通道和空间维度之间的关系。因此,作者认为它也可以扩展到探索batch维度中的关系。与transformer层的典型用法不同,BatchFormer的输入将首先被reshape,把整个batch视为一个序列。这样做就使transformer层中的自注意力机制就变成了BatchFormer不同样本之间的交叉注意力。
Shared Classifier
由于不能假设测试时的batch统计信息,如样本关系,因此BatchFormer模块之前和之后的特征之间可能存在差异。也就是说,我们无法通过直接移除BatchFormer来对新样本进行推理。因此,除了最终分类器之外,作者还引入了一个新的辅助分类器,这不仅可以从最终分类器中学习,还可以与BatchFormer之前的特征保持一致。为了实现这一点,只需在辅助分类器和最终分类器之间共享参数/权重。作者将这种简单而有效的策略称为“共享分类器”。通过提出的“共享分类器”,模型可以在测试期间移除BatchFormer模块,同时仍然受益于使用BatchFormer的样本关系学习。
BatchFormer在pytorch的伪代码如下所示:
BatchFormer: A Gradient View
为了帮助我们更好地理解BatchFormer是如何通过探索样本间的关系来辅助表示学习,作者还从梯度传播的角度为优化提供了直观的解释。直观地说,如果没有BatchFormer,所有损失只会在相应的样本和类别上传播梯度,即一对一,而使用BatchForme在其他样本上也存在梯度,如下图所示:
具体来说,给定N个样本集X以及对应的N个损失Loss,有:
BatchFormer对小batch中样本之间的关系建模,为每个标签yi隐式增加了 N-1 个虚拟样本,可以被视为一种数据增强方法。
实验结果
长尾识别
零样本学习
域泛化
自监督学习
消融实验之batch size
消融实验之shared classifier
结论
作者提出让深度神经网络本身能够探索每个小batch size中样本之间关系。首先将mini-batch中的每张图像(batch dimension)视为一个序列的一个节点,然后在这些图像之间建立一个Transformer Encoder Network来挖掘mini-batch中图像之间的关系。BatchFormer可以将每个标签的梯度传播到 mini-batch中的所有图像,这可以看作是虚拟的样本增强,从而提高了表示学习能力。此外,作者还在训练期间给BatchFormer的前后加入了权重共享分类器。最后作者展示了BatchFormer在十多个数据集上的有效性,并且在不同的任务上实现了性能显著提升。