对比学习学习笔记

什么是对比学习?

​ 官方解释:对比学习是一种机器学习的方法,该方法无需标注,通过告诉训练模型哪些数据点是相似的还是有差异的,来学习数据集中的通用特征。

​ 简单的通过案例来解释一下,下面是猫和狗的对比图:
图1

​ 即使是对于一个认知还未健全的小孩子来说,也很容易能够区分出,上图中两只猫是相似的,而下图中的,猫和狗是不同的。

​ 比如,我们很容易看到两只猫都有尖尖的耳朵,而狗只有耷拉下来的耳朵。或者我们可以将猫的平平的鼻子和狗凸出来的鼻子进行对比。

本质上,对比学习让我们机器学习的模型做同样的事情,就是通过判断数据对是相似的还是不同的来学习高维的数据特征,这个学习的过程甚至可以在一些分割或者分类的任务之前

为什么它如此强大?

​ 首先,我们可以在没有标注数据的情况下训练一个模型,也就是我们所说的自监督学习。其实在真实场景下,我们可能并没有所有数据的标签,特别是对于医学数据,我们获得准确的标签,需要更高的门槛。而使用对比学习的方法,即使只有一小部分标注的数据,也可以在一定程度上提高模型的性能。在了解了什么是对比学习,和为什么它这么有效之后,让我们来看下对比学习是怎样的一个原理吧。

对比学习原理

​ 在本文我们以simCLRv2为例进行阐述。

​ 整个的过程分为以下三个步骤:

​ 1.对于数据集中的每一个图片,我们可以采用两种数据增强的方法结合(比如,裁剪+改变尺寸+改变颜色,裁剪+改变颜色,改变大小+改变颜色)。我们的目的就是让模型认为这两个图片是相似的,因为他们本来就只是同一张照片的不同形式。

图2

​ 2.我们将这两张照片喂给深度学习模型(如ResNet)来将每一种图片生成表表达向量,目标就是用来训练模型以获得相似图像的相似表达。

图3

​ 3.最后,我们尽量去最小化对比损失函数来最大化这两个表达向量的相似度

图4

​ 训练一段时间后,模型将会学习到这两张猫的图片应该有相似的表达,并且猫的表达向量和狗的是不同的。这就说明,这个模型可以在不知道这个图像是什么的情况下,区分不同的类型的图片。

​ 我们可以对对比学习进行更深的分解成三个重要的部分:数据增广,编码,最小化损失

数据增广

图5

​ 我们随机组合以下的增广方法:裁剪,改变尺寸,色彩变形,灰度化,我们在每一个batch里,对每一个图片做两次变化,来创建一对正样本对。

编码

​ 我们采用大型的CNN网络,我们可以简单的看作是一个函数,h=f(x),其中x是我们的增广图像,该函数用于将图像编码为表达向量。

图6

​ CNN的输出将输入到一些Dense层中,这些层也被称为投影头(z = g(h)),将数据转换为其他的空间。这个提取的步骤通过实验证明可以提高性能。

​ 通过将图像压缩成一个隐空间表达,这个模型可以学习图像的高维特征,事实上,当我们持续训练这个模型来最大化相似图像的相似度,我们可以想象这个模型就是在隐空间对相似的点集进行聚类。

​ 比如说,猫的表达向量会聚集在一起,而远离狗的向量表达,这就是我们想要训练的效果。

损失最小化和表达

​ 已知有两个向量,z1、z2,我们需要一种方法来量化他们之间的相似度,由于我们是对两个向量进行比较,比较常用的是余弦相似度,基于空间向量之间的夹角来进行量化。

图7

​ 从逻辑上将,当两个向量越接近(夹角越小),它们更相似。因此如果我们以余弦作为量化指标,当角度越接近于0,两个向量则有更高的相似度,反之亦然。

(此过程在实现时可以采用以下两步:

1.对特征矩阵做正则化

2.将矩阵相乘 M * M^T

​ 除了余弦相似度,我们同时需要一个可以最小化的损失函数,其中一个选择就是NTXent(Normalized Temperature-Scaled Cross-Entropy Loss 正则化温度值交叉熵)。

​ 下面我们开始学习损失函数的定义

​ 我们首先计算两个增强图像的相似概率,采用了softmax的方式求得概率,下图可以形象的表示:

图8

​ 注意分母是e^{Similarity}(所有的图像对,包括负数对)。负数对取自增强图像产生的图像对,所有的其他图像都来自同一个bacth。我们采用-log函数来实现最小化最终的损失函数,即最大化两个图像的相似度。

图9

​ 最后,我们计算在一个batch中所有对的loss,求平均(设batch size N = 2)

图10

​ 基于这个Loss,编码器和投影头部表达不断的优化,这个表达将相似的图像放在更近的位置。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值