A simple framework for contrastive learning of visual representations
本文介绍了SimCLR:一个简单的视觉表征对比学习框架。我们简化了最近提出的对比性自我监督学习算法,不需要专门的架构或记忆库。为了了解是什么使对比性预测任务能够学习有用的表征,我们系统地研究了我们框架的主要组成部分。我们表明:(1)数据增强的组成在定义有效的预测任务中起着关键作用;(2)在表征和对比性损失之间引入可学习的非线性转换,大大改善了学习表征的质量;(3)与监督学习相比,对比性学习从更大的批次规模和更多的训练步骤中受益。
通过结合这些发现,我们能够在ImageNet上的自我监督和半监督学习方面大大超过以前的方法。在SimCLR学习的自监督表征上训练的线性分类器达到了76.5%的最高准确率,这比以前的技术水平提高了7%,与有监督的ResNet-50的性能相当。当只对1%的标签进行微调时,我们达到了85.8%的最高top-5准确率,超过了标签数量少100倍的AlexNet的表现。
1. Introduction
在没有人类监督的情况下学习有效的视觉表征是一个长期存在的问题。大多数主流方法属于两类中的一类:生成性或判别性。生成式方法学习生成或以其他方式模拟输入空间中的像素(Hinton等人,2006;Kingma & Welling,2013;Goodfellow等人,2014)。鉴别性方法使用与监督学习类似的目标函数来学习表征,但是训练网络来执行输入和标签都来自于未标记的数据集的前述任务。
最近,基于 latent space的对比学习的判别方法显示出巨大的前景,取得了最先进的结果(Hadsell等人,2006;Dosovitskiy等人,2014;Oord等人,2018;Bachman等人,2019)。在这项工作中,我们引入了一个简单的视觉表征对比学习的框架,我们称之为SimCLR。SimCLR不仅优于之前的工作(图1),而且更简单,既不需要专门的架构(Bachman等人,2019;Hénaff等人,2019),也不需要记忆库(Wu等人,2018;Tian等人,2019;He等人,2019;Misra & van der Maaten, 2019)。
为了了解是什么促成了良好的对比性表征学习,我们系统地研究了我们框架的主要组成部分,并表明:
-
在定义产生有效表征的对比性预测任务时,多种数据增强操作的组成是至关重要的。此外,无监督的对比性学习比有监督的学习更受益于强大的数据增强。
-
在表征和对比性损失之间引入可学习的非线性转换,可以极大地提高学习表征的质量。
-
对比性交叉熵损失的表征学习得益于normalized embeddings和适当调整的温度参数。
-
与监督学习相比,对比学习受益于更大的批次规模和更长的训练。与监督学习一样,对比学习也得益于更深更广的网络。
我们结合这些发现,在ImageNet ILSVRC-2012(Russakovsky等人,2015)上实现了自监督和半监督学习的新的最先进水平。在线性评估协议下,SimCLR实现了76.5%的top-1准确率,比之前的最先进水平(Hénaff等人,2019)有7%的相对提高。当只用1%的ImageNet标签进行微调时,SimCLR实现了85.8%的前5名准确率,相对提高了10%(Hénaff等人,2019年)。当在其他自然图像分类数据集上进行微调时,SimCLR在12个数据集中的10个表现与强监督基线(Kornblith等人,2019)相当或更好。
2. Method
2.1. The Contrastive Learning Framework
图2. 一个简单的视觉表征对比学习的框架。两个独立的数据增强算子从同一个增强系列中取样( t ∼ T t∼\mathcal T t∼T和 t ′ ∼ T t'∼\mathcal T t′∼T),并应用于每个数据实例,以获得两个相关的视图。一个编码器网络 f ( ⋅ ) f(\cdot) f(⋅)和一个projection head g ( ⋅ ) g(\cdot) g(⋅)被训练成使用对比性损失来最大化协议。训练完成后,我们扔掉projection head g ( ⋅ ) g(\cdot) g(⋅),使用编码器 f ( ⋅ ) f(\cdot) f(⋅)和representation h进行下游任务。
受最近的对比学习算法的启发(见第7节的概述),SimCLR通过潜空间中的对比损失,使同一数据实例的不同增强视图之间的一致性最大化,从而学习表征。如图2所示,这个框架包括以下四个主要部分。
-
一个随机的数据增强模块,对任何给定的数据实例进行随机transforms,导致同一实例的两个相关视图,表示为 x ~ i \widetilde x_i x i和 x ~ j \widetilde x_j x j,我们认为这是一个positive pair。在这项工作中,我们依次应用三种简单的增强:随机裁剪,然后调整到原始尺寸,随机颜色扭曲,以及随机高斯模糊。如第3节所示,随机裁剪和颜色扭曲的组合对于实现良好的性能至关重要。
-
一个神经网络基础编码器 f ( ⋅ ) f(\cdot) f(⋅),从增强的数据实例中提取representation vectors。我们的框架允许对网络架构进行各种选择,没有任何限制。我们选择简单,采用常用的ResNet(He等人,2016),得到 h i = f ( x ~ i ) = R e s N e t ( x ~ i ) h_i = f(\widetilde x_i) = ResNet(\widetilde x_i) hi=f(x i)=ResNet(x i),其中 h i ∈ R d h_i∈\mathbb R^d hi∈Rd是平均池化层之后的输出。
-
神经网络 g ( ⋅ ) g(\cdot) g(⋅)投影 h i h_i hi到 z i z_i zi, z i z_i zi被用于contrastive loss的计算。 z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) , σ z_i=g(h_i)=W^{(2)}\sigma(W^{(1)}h_i),\sigma zi=g(hi)=W(2)σ(W(1)hi),σ为ReLU函数。我们发现在 z i z_i zi计算 h i h_i hi会更有更好的效果。
-
为对比性预测任务定义的对比性损失函数。给定一组 x ~ k {\widetilde x_k} x k包括一对positive样本 x ~ i 和 x ~ j \widetilde x_i和\widetilde x_j x i和x j,对比性预测任务的目的是给定 x ~ i \widetilde x_i x i在 { x ~ k } k ≠ i \{\widetilde x_k\}_{k\ne i} {x k}k=i中识别 x ~ j \widetilde x_j x j。
我们随机抽取N个样本的minibatch,对每一个样本进行数据增强,从而得到2N个数据点。我们不对negative的样本进行明确的抽样。相反,与(Chen等人,2017)类似,我们将minibatch中的其他2(N - 1)个增强的样本视为负面例子。让 s i m ( u , v ) = u T v / ∣ ∣ u ∣ ∣ ∣ ∣ v ∣ ∣ sim(u, v) = u^Tv/||u||||v|| sim(u,v)=uTv/∣∣u∣∣∣∣v∣∣表示l2归一化的u和v之间的点积(即余弦相似度)。那么,一对positive pair(i,j)的损失函数定义为
L [ k ≠ i ] ∈ { 0 , 1 } \mathbb L_{[k\ne i]}\in\{0,1\} L[k=i]∈{0,1}是n indicator function,当 k ≠ i k\ne i k=i时为1。 τ \tau τ为温度参数。
最终的损失是在一个minibatch中所有的positive pairs中计算的,包括(i,j)和(j,i)。这个损失已经在以前的工作中使用(Sohn,2016;Wu等人,2018;Oord等人,2018);为了方便,我们称之为NT-Xent(归一化温度标度的交叉熵损失)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZEn2RUbM-1624353609771)(002.jpg)]
2.2. Training with Large Batch Size
为了保持简单,我们不使用记忆库来训练模型(Wu等人,2018;He等人,2019)。相反,我们将训练batch大小N从256变化到8192。8192的batch大小给了我们每个来自两个augmentation views的positive pair 16382个负面例子。在使用线性学习率缩放的标准SGD/Momentum时,大batch的训练可能是不稳定的(Goyal等人,2017)。为了稳定训练,我们对所有batch大小使用LARS优化器(You等人,2017)。我们用云计算TPU训练我们的模型,根据批次大小,使用32到128个核心。
Global BN. 标准ResNets使用批量归一化(Ioffe & Szegedy,2015)。在具有数据并行性的分布式训练中,BN的均值和方差通常在每个设备上进行本地聚合。在我们的对比学习中,由于positive pairs是在同一个设备中计算的,模型可以利用local information leakage来提高预测的准确性,而不需要改进表示。我们通过在训练期间将BN的平均值和方差汇总到所有设备上来解决这个问题。其他方法包括跨设备shuffling数据实例(He等人,2019年),或用层规范代替 layer norm(Hénaff等人,2019年)。
2.3. Evaluation Protocol
Dataset and Metrics. 我们对无监督预训练(在没有标签的情况下学习编码器网络f)的大部分研究是使用ImageNet ILSVRC-2012数据集完成的(Russakovsky等人,2015)。在CIFAR-10(Krizhevsky & Hinton, 2009)上的一些额外预训练实验可以在附录B.9中找到。我们还在广泛的数据集上测试了预训练的结果,用于迁移学习。为了评估学到的表征,我们遵循广泛使用的线性评估协议(Zhang等人,2016;Oord等人,2018;Bachman等人,2019;Kolesnikov等人,2019),在冻结的基础网络之上训练一个线性分类器,并将测试精度作为表征质量的代理。除了线性评估,我们还与半监督和迁移学习的最先进技术进行比较。
络之上训练一个线性分类器,并将测试精度作为表征质量的代理。除了线性评估,我们还与半监督和迁移学习的最先进技术进行比较。
Default setting. 除非另有说明,对于数据增强,我们使用随机裁剪和调整大小(随机翻转)、颜色失真和高斯模糊(详见附录A)。我们使用ResNet-50作为基础编码器网络,并使用2层MLP投影头将表征投射到128维的潜空间。作为损失,我们使用NT-Xent,用LARS优化,学习率为4.8(=0.3×BatchSize/256),权重衰减为10-6。我们在批次大小为4096的情况下训练100个历时。3 此外,我们在前10个历时中使用线性预热,并使用余弦衰减时间表衰减学习率,不重新启动(Loshchilov & Hutter, 2016)。