Momentum contrast for unsupervised visual representation learning
无监督-对比学习
先介绍一些对比学习的一些概念和工作
引用leCun在NeurIPS 2016的一张图,机器学习领域,强化学习是蛋糕上的樱桃,监督学习是蛋糕上的糖霜,只有无监督学习才是蛋糕的本质
模型不需要识别出图片的精确类别,模型只需要判断出哪些图片是类似的,哪些图片是差异巨大的。比方三张图片(男人,女人,狗),网络能抽取出特征f1 f2 f3,同时能做到拉近f1 和 f2 同时远离f3。
但是这里面有一种隐藏标签,模型的标签还是要告诉哪些是类似的哪些是不同的。
在视觉领域会设计一些代理任务(pretext task),定义一个规则,代理任务会提供一个监督信号,这也是所谓的自监督训练。
最广为应用的代理任务,instance discrimination的步骤
- 从一个未知数据集抽取一张图片f1
- 随机裁剪和数据增广得到另外两张图 f11,f12,其中f111叫anchor,f12叫poitive,这两张图片看起来很不一样,但是语义信息不变(来自于同一张照片,比如黑天的狗和白天的狗?),这两张图片就作为正样本,其他所有图片都作为负样本neagitve,即一张图片自成一类
Abstract
- 一个动态字典,不需要梯度回传,存储负样本而且可以做的很大
- 移动平均编码器,让字典数据尽量均衡
- linear protocal任务,在预训练好了backbone之后将其freeze,然后去训练其他下游模型
- 填补上了有监督和无监督之间的gap
Introduction
- MoCo是无监督里第一个可以在视觉众多下游任务中取得优异成绩的无监督预训练模型,视觉领域的大量的模型都是有监督预训练的,而nlp中的bert等无监督预模型取得了很好的成果,区别在于信号空间,图像时连续的,文字时离散的。能不能构建动态字典,对是否能对无监督学习建模是至关重要的。
- moco里用了两组抽取特征的编码器,如图1,分别用于得到anchor feature f11和 positive feature f12+neagtive feature f2,因为positive/negative都是相对于anchor而言的,为了保持特征一致性 PN采用了相同的编码器。
- 有关动态字典
The "keys" (tokens) in the dictionary are sampled from data (e.g images or pathes) and are represented by an encoder network
我们把 f11 f12 ... 与 f2 f3... 是为一个字典所有 keys,其中f12是正样本key+,把f11是为字典的query,那么对比学习就转化为了一个字典查询的问题,query 只需找到与自己最相似的key。
moco对这个字典有两个要求:
(1)large :这个字典才能保证可以表示出图形最本质的特征,可以用队列数据结构来解决显卡显存的问题,队列也被称为FIFO的数据结构,也就是先进先出,动态字典是不断更新的,当新数据加入队列时,最老的那一批数据需要从队列中出去。
(2)consisent,字典的key需要用相同或相似的编码器得到,才能保证对比一致。可以用动量的方法解决这个一致性的问题,把抽取队列编码器的参数θk 的动量momutunm选的很大,这样θk就会缓慢的更新,尽量保证一致性
而目前的对比学习方法或多或少的会被其中的一条限制
5 Conclusion
moco可以在中型数据集imagenet和大型数据集Instagram都能取得很好的结果,可以成功的把无监督学习和有监督学习的坑填平
作者注意到从1M数据迁移到10亿数据只能带来0.几个点的提升,他分析是大规模的数据集没有被很好的利用起来,所以也许一些新的pretext tasks 像masked auto-encoding 能起到很好的效果(就是利用nlp的完型填空,这就是MAE,kaming那篇,看来MAE也要看一下)
2 Related work
无监督学习/自监督学习的工作主要包括两个方面:
1)pretext task
denoising auto-encoder : 重建整张图
context autuo-encoder :重建某个patch
colorization : 给图片上色
("exemplar")image : 给图片生成伪标签,这些图片都属于同一个类,一种数据增广
patch orderings : (九宫格方法)一张图片分成3*3个patch,选定其中一个参考patch,再选定一张测试patch,模型给出测试patch在参考patch的哪个方位。
对比学习可以和其中某个代理任务联用。
之前两个比较重要的工作:
CMC使用的是上下文预测,与context auto-encoder类似
CPC使用了同一对象的不同视角,与colorization类似
2)loss function:根据生成式网络和判别式网络选择不同的loss function。
对比学习的目标函数不同于一般的无监督学习目标函数,对比学习的训练目标是随着学习的过程不停改变的,主要的目的在于区分相似的目标和不相似的目标。
此外还有对抗性的目标函数主要用于GAN网络的图片生成。主要是用于衡量两个概率分布之间的差距。
3 Method
3.1
对比学习可以看成是训练一个编码器,去做字典查找的工作。
损失函数的设计方面,我们希望,训练好的模型提取了query和key+的特征相似时,loss尽可能的小,在输入query 和负样本key时,loss也尽可能的小,因为这样达到了我们最初的设计目的。反之亦然。
moco使用了NCEloss,NCEloss是根据cross entropy改进而来的,在说CEloss之前,首先回忆一下 softmax
CEloss
可以注意到CEloss就是在softmax的前面套入了-log,这是因为softmax是用来归一化类别的得到了一个概率分布p,而CEloss与KL散度类似,都是用来估计概率分布之间的差距,所以softmax很自然的就介入到CEloss中。
上面的两个都是CEloss,第二个是特指多分类的CEloss,其实两者的格式并不冲突,因为yic只有0和1两种可能,如果类别采用了one-hot编码的形式,那么两者就一致了。
但是这种格式却不能直接应用到moco中,因为类别种类j会变得很大(因为moco的训练每张图片自成一类),在上百万类别下的softmax就不管用了。
NCE(noise contrastivee loss function)
NCE将数据分成正样本和噪声样本,所以以前的多分类问题就变成了2分类问题,但是这个计算复杂度的问题还没有解决。
这里使用取近似的办法来解决负样本数太多,导致的计算复杂度过高的问题,就是抽取负样本的一部分进行估计,抽取负样本数越多,近似效果越好(所以这个字典要足够large),反之亦然。
InfoNCEloss
作者认为仅仅将数据归纳为2分类不太全面,还是要多分类的
上式的qk得到了一个logits,和softmax得出的结果近似,τ代表了一个温度超参数,用于调整logits的分布,K= 抽样负样本的数量(字典里负样本的数量)+ 1个正样本的数量 = 字典里所有样本的数量。
模型的输入和模型本身的设计
模型的输入xq和xk可以是images ; patches
模型本身fq ;fk,可以是完全一致的;部分参数共享的 ;完全不一致的
而这都是由代理任务决定的。
3.2
Momentum update
由于字典的长度很长,不能用梯度回传的方式更新这个fk,简单的方式就是每次训练的iteration之后,将fq的值赋给fk,但是这种方法取得的结果并不好,所以要采用动量的方式更新(公式2),第一代θk的初始化是由θq来做的,后面的更新主要靠自己。
之前的工作往往受限于字典的大小或字典的一致性
如图
(a) 在端到端的学习框架中,
encoder q和 encoder k 采用了相同的参数,由于他们来自于同一个mini batch size 所以数据的分布是一致的,同时端到端的学习允许了对encoder k的梯度回传,可以使得K有很好的一致性,但是在端到端的学习方案中,mini batch 的大小就等于字典的大小,由于硬件不能承担太大的batch size,使得字典的大小成为了一个问题。谷歌由于拥有充足的算力,使得训练SimCLR这种模型的batch可以达到8192,产生的1万多的负样本池可以充足的训练对比学习,所以SimCLR采用了端到端的学习方式
(b)memory bank,
memory bank 没有编码器,所以采用了将全部数据集的特征都存入bank中,以采样的方式得到负样本,要知道存储key所消耗的内存要远远小于存储图片所消耗的内存。有得必有失,memeory bank的具体做法是,
从bank中采样n 个 key 去组成dictionary和query做一次训练和模型更新,
然后使用更新之后的encoder,去重新计算一下此次使用的这n个样本(图片)的key值
将这些新的key值重新放回到么memory bank中,替代原来旧的key值。
这样无形之中就加重了负样本特征的不一致性问题。
(c) moco
end-to-end 的计算和更新都在线上,占用显存
memory bank 正样本encoder的计算和更新在线上,负样本没有encoder 不存在计算和更新的问题
moco的正样本encoder计算和更新在线上,负样本encoder的计算和更新在线下。
总体来说moco的做法和memory bank更加类似,都是只有一个编码器是使用梯度回传来更新的,字典都采用了额外的数据结构进行存储从而于batch size剥离开。
配合伪代码理解
fq和fk是两个encoder
x_q是query数据,x_k 是key数据,维度都是256*128,batch size = 128
bmm和mm都是计算logits
l_pos的特征维度就变成了256*1
queue就是负样本字典 维度是128*65536
l_neg的特征维度是256*65536
所有的logits 拼接起来 logitis的维度是256*65537
ground_truth 设计为全0,即第0类,按照伪代码的实现方式,正样本永远位于第一个,如果找对了那个key,那么在分类任务里得到的类别就是类别0
f_k.params 队列更新
enqueue 队列进
dequeue队列出
———————————————————————————————————————————trick
shuffling BN :在transformer中引起了很大争论的就是BN,大部分现在用的都LN