1 概要
交叉熵损失是监督学习中应用最广泛的损失函数,度量两个分布(标签分布和经验回归分布)之间的KL散度,但是也存在对于有噪声的标签缺乏鲁棒性、可能存在差裕度(允许有余地的余度)导致泛化性能下降的问题。而大多数替代方案还不能很好地用于像ImageNet这样的大规模数据集。
许多对正则交叉熵的改进实际上是通过对loss定义的放宽进行的,特别是参考分布是轴对称的。这写改进通常具有不同的动机:比如标签平滑(Label smoothing)通过偏离轴来模糊区分正确和不正确的标签,从而在许多应用中提供了很小但是很重要的提升;在自蒸馏中,利用前几轮的“软”标签作为参考类分布进行多轮交叉熵训练;混合和相关数据增强策略通常通过线性插值创建明确的、新的训练示例,然后将相同的线性插值应用于目标标签分布,类似于软化原始交叉熵loss。用这些修改方法训练的模型显示了改进的泛化、鲁棒性和校准。
本文提出了一个新的loss,受对比loss与度量学习启发,完全去除参考分布,而只是将来自相同类的规范化嵌入强行加在一起,使得其比来自不同类的嵌入更加紧密。
具体来说,在对比学习中,核心思想是拉近某一个锚点与其正样本之间的距离,拉远锚点与该锚点其他负样本之间的距离,通常来说,一个锚点只有一个正样本,其他全视为负样本。而本文的方法认为每个锚点有许多的正样本,而不是许多负样本,并且通过标签显示样本之间的正负关联。比如下面的图,右侧是典型的对比学习方法,通常将一张原图通过数据增强得到两个子样本,这一对子样本之间构成一对正对,而与其他数据的子样本构成负对;而本文的有监督对比学习中,每个子样本可能都有很多的正对和负对。
本文构造的loss在ResNet50和ResNet200上都取得了不错的Top-1效果,在自动增强的ResNet50上取得78。8%的Top-1精度,比同样数据增强下的交叉熵loss提升了1.6%,不仅如此,还更鲁棒。
具体的Contribution如下:
- 我们提出了一个新的扩展对比损失函数,允许每个锚点有多个正对。因此,我们将对比学习适应于完全监督的设置。
- 我们表明,与交叉熵相比,这种损失使我们能够了解最先进的表示方式,从而显著提高了Top-1的准确性和鲁棒性。
- 我们的损失对超参数范围的敏感性不如交叉熵。这是一个重要的实际考虑。我们相信,这是由于我们的损失使用更自然的公式,使从同一类样本的代表被拉得更近,而不是像交叉熵一样强迫他们被拉向一个特定的目标。
- 我们分析地表明,我们的损失函数的梯度鼓励从hard positive和hard negative中学习。我们还表明,三联体损失是我们损失只有一个正极和负极被使用的一个特例。
具体来说,有监督对比学习的框架是交叉熵loss和传统对比学习的结合:
2 具体结构
2.1 表征学习框架
总的来说,有监督对比学习框架的结构类似于表征学习框架,由如下几个部分组成:
-
数据增强模块
数据增强模块 A ( ⋅ ) A(·) A(⋅)的作用是将输入图像转换为随机增强的图像 x ~ \widetilde{x} x ,对每张图像都生成两张增强的子图像,代表原始数据的不同视图。数据增强分为两个阶段:第一阶段是对数据进行随机裁剪,然后将其调整为原分辨率大小;第二阶段使用了三种不同的增强方法,具体包括:(1)自动增强,(2)随机增强,(3)Sim增强(按照顺序进行随机颜色失真和高斯模糊,并可能在序列最后进行额外的稀疏图像扭曲操作)。
-
编码器网络
编码器网络 E ( ⋅ ) E(·) E(⋅)的作用是将增强后的图像 x ~ \widetilde{x} x 映射到表征空间,每对子图像输入到同一个编码器中得到一对表征向量,本文用的是ResNet50和ResNet200,最后使用池化层得到一个2048维的表征向量。表征层使用单位超球面进行正则化。
-
投影网络
投影网络 P ( ⋅ ) P(·) P(⋅)的作用是将表征向量映射成一个最终向量 z z z进行loss的计算,本文用的是只有一个隐藏层的多层感知器,输出维度为128。同样使用单位超球面进行正则化。在训练完成后,这个网络会被一个单一线性层取代。
2.2 对比损失
本文的数据是带有标签的,采用mini batch的方法获取数据,首先从数据中随机采样 N N N个样本对,记为 { x k , y k } k = 1 , 2 , . . . , N \left\{ {x}_k,{y}_k\right\}_{k=1,2,...,N} { xk,yk}k=1,2,...,N, y k {y}_k yk是 x k {x}_k xk的标签,之后进行数据增强获得 2 N 2N 2N个数据样本 { x ~ k , y ~ k } k = 1 , 2 , . . . , 2 N \left\{\widetilde{x}_k,\widetilde{y}_k\right\}_{k=1,2,...,2N} { x