对比学习(Contrastive Learning)
基本概念
传统任务中,对于预训练模型,在训练中,大多都是用分类去做的。但是这种方法具有一定的局限性,例如,在做一个分类中,我们最初会给模型确定一个类别值,即要分成多少类别,而这种设定,恰恰可能会给我们的模型强加一个束缚,即在训练中,每次都是不断地提取关于这些指定类别的特征,而不去关注其他特征,限制了我们模型和数据的最大潜力——我们给定了标签,限制了模型就干什么,而我们应该在训练模型中,不要让他被标签所束缚,模型的潜力应该由他自己挖掘
对比学习(自监督学习)——不需要准备标签,使得我们的模型具有特征提取的能力
通俗来讲,对比学习,就是要判断异同——相同的就是正例,不同的就是负例,让模型学习其中的规律
标签的定义
对比学习中,最重要的是不进行标签定义
而是要通过模型自己去学习数据中的规律,自己进行标签的定义
如何表示特征
假设我们已经有了样本,为了区分正例与负例,我们应该计算两个样本之间的相似度
以输入数据为图片为例:
- 将图像做成向量(
encoder
过程) - 定义相似度的函数(余弦相似度)
- 计算正负样本之间的距离
- 正样本之间越近越好
- 负样本之间越远越好
SimCLR Framework
在对比学习论文中,SimCLR比较具有代表性,因此一下笔记都是基于SimCLR模型进行介绍对比学习
整体模型
输入数据后
首先,对数据进行数据增强(例如裁剪、调整大小、旋转、重新着色等),对输入数据进行随机数据增强,得到两个样本 x i x_i xi 和 x j x_j xj ,这两个样本属于正样本(来自于同一张图片)
然后,进行**encoder
过程**,对这两个数据进行编码,转化为向量形式
h
i
h_i
hi 和
h
j
h_j
hj
最后通过一个全连接层(Dense + Relu + Dense),的到最后的两个输出 z i z_i zi 和 z j z_j zj,并且计算他们之间的相似度
而我们的目标函数为最小化两个输出之间的相似度
对比学习给我们提供的是特征提取的能力,能够提取更具有代表性的特征,在下游任务中的表现更好
注意,是 h i h_i hi 和 h j h_j hj 应用于下游任务,而不是经过全连接层加工之后的数据应用于下游任务。因为使用原生数据,可以使得模型的泛化能力更强
同时,在模型的学习中,
- 为了防止模型坍塌,还要注意正样本和负样本要同时学习,
- 并且要求对比学习的训练过程中时,
batch
设置的比较大,能够更好的提高模型的学习能力
batch
数据输入
假设一个批量为2,输入数据后,会进行随机数据增强:
如上图中,红色的连线是正样本,蓝色的连线是负样本,他们是互为的关系
特征提取
对于特征提取(encoder
部分),可以使用transform
模型进行
并且 encoder
部分可大可小,一般而言,使用的模型越大,效果可能越好
基本思想
其实就是同类越相似(一般使用余弦相似度来定义)
一般而言,对于上图右侧的矩阵,要去掉对角线上的元素
因为对角线元素代表自己跟自己,如果我们给他一个明确的定义来计算,那么可能会使得同类之间学习的效果较差(浅绿色部分),因此需要去除对角线
损失函数的设计
l
i
,
j
=
−
l
o
g
e
x
p
(
s
i
m
(
z
i
,
z
j
)
/
τ
)
∑
k
=
1
2
N
1
[
k
≠
i
]
e
x
p
(
s
i
m
(
z
i
,
z
k
)
/
τ
)
l_{i,j} = -log\frac{exp(sim(z_i, z_j) / \tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}exp(sim(z_i, z_k) / \tau)}
li,j=−log∑k=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
分子表示同类,分母表示不同类。
对于 除以 τ \tau τ,由于该数一般小于1,能够起到增大负例之间的差异、正例之间的相似
数据增强技术
在论文中提到,数据增强逐渐成为核心技术
数据增强的方法越花里胡哨可能效果越好,因为越花里胡哨,会使得数据的变化越大,就需要学习更多的东西
SimCLR V2
相比于SimCLR第一个版本,把模型做的更大,加入了蒸馏技术
Multiview Coding多视角
一方面,为了增加训练的难度,另一方面为了更好的计法模型的潜能,通常还需要对数据做一个多视角任务
多视角相当于对数据进行变化,从多个不同的视角进行变换
如上图,我们分别对图片进行一些变化(分割,深度等),但无论我们做什么变化,都是基于原图进行变化的,因此他们仍然要属于同一个类别
BYOL
不需要将负样本加入训练,就可以进行对比学习
整体流程为:
- 输入数据,依次进行编码、全连接层,最终得到两个输出
- 然后将其中一个输出作为目标值(如上图下面的输出),另一个输出作为预测值,类似于做一个回归的任务进行预测
但是,有个外国大佬对其进行过质疑,因此在该模型中的全连接层中(MLP
),使用了batch normalization
,即使用整个批量进行了标准化过程,因此会导致数据间接的利用了负样本的一些特征信息,而不是真正的完全脱离负样本进行训练