对比学习,无监督学习,自监督学习基本上是一回事。
本文主要介绍 SimCLR框架:一个简单的视觉表示对比学习框架,不仅比以前的工作更出色,而且也更简单,既不需要专门的架构,也不需要储存库。
SimCLR框架优势:
- 多个数据增强组合对于定义产生有效表示的对比预测任务至关重要,尤其是随机裁剪和颜色是很好的组合。 此外,无监督的对比学习受益于比监督学习更强的数据增强。
- 在表示(h)和对比损失(在)之间引入可学习的非线性变换,大大提高了学习表示的质量。
- 具有对比交叉熵损失的表示学习受益于归一化嵌入和适当的温度参数 ττ。
- 对比学习与监督学习相比,受益于更大的批量(BatchBatch)和更长的训练时间。 与监督学习一样,对比学习也受益于更深更广的网络。
SimCLR 框架包括以下四个主要组件:
1、随机数据增强模块。随机转换任何给定的数据示例,生成同一数据示例的两个相关视图,表示并定义 xi和 xj是正对。本文组合应用三种增强:随机裁剪然后调整回原始大小(random cropping and resize back)、随机颜色失真(color distortions) 和 随机高斯模糊(random Gaussian blur)。
2、基础编码器(base encoder) f(⋅)。用于从生成的视图中提取表示向量,允许选择各种网络架构。本文选择 ResNetResNet 获得hi=f(xi)=ResNet(xi),生成的表示hi∈Rd是平均池化层(averagepoolinglayer)后的输出。
3、投影头(projection head) g(·)将表示映射到应用对比损失的空间。 本文使用一个带有一个隐藏层的 MLP 来获得 zi=g(hi)=w(2)σ(w(1)hi) 其中 σ是一个ReLU 非线性函数。此外,发现在zi而非 hi上定义对比损失是有益的。
4、对比损失函数(contrastive loss function)。 给定 batchbatch 中一组生成的视图 {xk},其中包括一对正例 xi 和 xj ,对比预测任务旨在对给定 xi 识别 {xj}k≠i 中的xj 。
随机抽取 N 个样本的小批量样本,并在从小批量样本上生成增强视图,从而产生 2N 个数据点。 本文无明确地指定负例,而是给定一个正对(positivepair),将小批量中的其他 2N−2个增强示例视为负示例。本文定义相似度为余弦相似度。则一对正对 (i,j)的损失函数定义为:
其中 1[k≠i]∈{0,1} 是指示函数,当 kk≠i 为 1 。τ是温度参数。最终损失是在小批量中计算所有正对 (i,j)和 (j,i)。为方便起见,将其称为 NT−Xent(归一化温度标度交叉熵损失)。
算法流程
图解算法流程:
Step1:随机数据增强模块
首先,原始图像数据集生成若干大小为 N 的 batch。这里假设取一批大小为 N=2 的 batch。本文使用 8192 的大 batch。
定义随机数据增强函数 T
对于 batch中的每一幅图像,使用随机数据增强函数 T 得到一对view。对 batch 为 2 的情况,得到 2N=4 张图像。
Step2:基础编码器(base encoder) f(⋅)
对增强过的图像通过一个编码器来获得图像表示。所使用的编码器是通用的,可与其他架构替换。下面的两个编码器共享权值,得到表示vector hi和hj。
在本文中,作者使用 ResNet−50 架构作为编码器。输出是一个2048 维的向量 h。
Step:投影头(projection head) g(·)将表示映射到应用对比损失的空间。
本文使用一个带有一个隐藏层的 MLP 来获得 zi=g(hi)=w(2)σ(w(1)hi) 其中 σ是一个 ReLU非线性函数。
Step4:使用对比损失函数进行模型调优。
对于 batchbatch 中的每个增强过的图像通过基础编码器 f(⋅)f(⋅),得到嵌入向量 zz。
使用嵌入向量zizi,计算损失的步骤如下:
a. 计算余弦相似性
用余弦相似度计算图像的两个增强的图像之间的相似度。对于两个增强的图像 xixi 和 xjxj,在其投影表示 zizi 和 zjzj 上计算余弦相似度。
使用上述公式计算 batchbatch 中每个增强图像之间的两两余弦相似度。如图所示,在理想情况下,增强后的猫的图像之间的相似度会很高,而猫和大象图像之间的相似度会较低。
b. 损失的计算
SimCLRSimCLR使用了一种对比损失,称为“NT−XentNT−Xent损失”(归一化温度-尺度交叉熵损失)。工作步骤如下:
首先,将 batchbatch 的增强对逐个取出。
接下来,我们使用和 softmaxsoftmax 函数原理相似的函数来得到这两个图像相似的概率。
这种 softmax 计算等效于获得第二张增强猫图像与该对中的第一张猫图像最相似的概率。批次中的所有剩余图像都被采样为不同的图像(负对)。 因此,我们不需要像 InstDisc、MoCo 或 PIRL 等以前的方法那样需要专门的架构、存储库或队列。
然后,取上述计算的负对数来计算这一对图像的损失。
图像位置互换,再次计算同一对图像的损失。
计算 BatchsizeN=2BatchsizeN=2 的所有配对的损失并取平均值。
最后,更新网络 f(⋅)和 g(⋅) 以及最小化 L。
下游任务
一旦SimCLR模型被训练在对比学习任务上,它就可以用于迁移学习。为此,使用来自编码器的表示,而不是从投影头获得的表示。这些表示可以用于像ImageNet分类这样的下游任务。
为了稳定训练,我们对所有批大小使用LARS优化器,而不是SGD。学习速率为4.8 (= 0.3 × BatchSize/256),权重衰减为10−6。
引用:
(52条消息) 图解SimCLR框架,用对比学习得到一个好的视觉预训练模型_ronghuaiyang的博客-CSDN博客