Contrastive Learning


Representation learning \text{Representation~learning} Representation learning 的目标是为输入样本 x x x, 学习一个有效的表示 z z z

一般来说,基于多视角预测框架的 Self-Supervised Learning \text{Self-Supervised~Learning} Self-Supervised Learning,是通过预测同一张图片的不同视角来学习特征:同一个图片样本,采用不同的数据增强方法生成两个视角,模型最终学习得到相同的特征表示。理论上来说,这会产生一个问题: collapsed representation \text{collapsed~representation} collapsed representation:不管输入的图片样本是什么,模型输出的特征都是一样的。

为解决这一问题, Contrastive Learning \text{Contrastive~Learning} Contrastive Learning 引入负样本,同一图片样本使用不同数据增强(正样本对)后,模型学习得到的正样本表示之间距离尽可能相近,不同图片负样本之间表示距离尽可能远离,从而确保模型学习数据表示具有合理的区分度。

总的来收, Contrastive Learning \text{Contrastive~Learning} Contrastive Learning 的核心是通过计算样本表示之间的距离,实现正样本靠近,负样本远离

  • 输入 N N N samples \text{samples} samples,使用不同的数据增强方法,为每张图片生成两个 view:  y, y’ \text{view: ~y,~y'} view:  y, y’

  • 两个 batch \text{batch} batch 的样本表示之间计算其 cosine \text{cosine} cosine,得到相似度矩阵 A ∈ R N × N A\in\mathbb{R}^{N\times N} ARN×N,其对角线位置代表样本表示 y and y’ \text{y~and~y'} y and y’ 相之间的相似度,其余为 y \text{y} y 与其他 N − 1 N-1 N1 个负样本之间的相似度度量。

  • representation matrix \text{representation~matrix} representation matrix 的每一行做 softmax \text{softmax} softmax 分类,采用 Cross Entropy \text{Cross~Entropy} Cross Entropy 作为损失函数,得到 Contrastive Learning \text{Contrastive~Learning} Contrastive Learning L o s s Loss Loss
    L y = − log exp ( sim ( y ,   y ′ ) / τ ) ∑ i = 0 N exp ( sim ( y ,   y ′ ) / τ ) \mathcal{L}_y=-\text{log}\frac{\text{exp}(\text{sim}(y,~y')/\tau)}{\sum_{i=0}^N\text{exp}(\text{sim}(y,~y')/\tau)} Ly=logi=0Nexp(sim(y, y)/τexp(sim(y, y)/τ)

MoCo (CVPR20) \text{MoCo~(CVPR20)} MoCo (CVPR20)

Contrastive Learning \text{Contrastive~Learning} Contrastive Learning 着力于学习正负样本的有效表示。一般来说,负样本越多,学习得到的 representation \text{representation} representation 判别性越强, 能够有效防止 collapsed representation \text{collapsed~representation} collapsed representation

如何增加负样本数量?

  • 一种方式是增加 batch size \text{batch~size} batch size ,会受到 GPU \text{GPU} GPU 显存等计算资源的限制;

  • 另一种是使用 Memory bank \text{Memory~bank} Memory bank:把之前样本的 representation \text{representation} representation 保存下来。

    虽然这样解决了计算资源等的限制问题,但是 Memory bank \text{Memory~bank} Memory bank 中的 representation \text{representation} representation 是通过差异较大的 encoder \text{encoder} encoder BP \text{BP} BP 回传一次,更新一次 encoder \text{encoder} encoder)输出的,存在前后 encoder \text{encoder} encoder 不连续问题。

Kaiming He etal. \text{Kaiming~He~etal.} Kaiming He etal. 推出 MoCO (Momentum Contrast) \text{MoCO~(Momentum~Contrast)} MoCO (Momentum Contrast),采用两个 encoder \text{encoder} encoder 对输入进行编码:

  • query encoder \text{query~encoder} query encoder query \text{query} query 进行编码;
  • momentum encoder \text{momentum~encoder} momentum encoder key \text{key} key 进行编码。

用两个 encoder \text{encoder} encoder 学习得到的 representation \text{representation} representation 计算 L o s s Loss Loss
L y = − log exp ( q ⋅ k + / τ ) ∑ i = 0 N exp ( q ⋅ k i / τ ) \mathcal{L}_y=-\text{log}\frac{\text{exp}(q\cdot k_+/\tau)}{\sum_{i=0}^N\text{exp}(q\cdot k_i/\tau)} Ly=logi=0Nexp(qki/τ)exp(qk+/τ)

对每个 batch \text{batch} batch 的样本 x x x

  • 随机增强两个输入: x q ,   x k x_q,~x_k xq, xk

  • 编码器输出: q = f q ( x q ) ,   k = f k ( x k ) q=f_q(x_q),~k=f_k(x_k) q=fq(xq), k=fk(xk),并去掉 k k k 的梯度更新;

  • 计算 cosine \text{cosine} cosine

    • q q q k k k 一一对应相乘,得到正样本之间的余弦相似度: A p o s ∈ R N × 1 A_{pos}\in \mathbb{R}^{N\times 1} AposRN×1
    • q q q Memory bank \text{Memory~bank} Memory bank 中存储的 K K K 个负样本相乘,得到正负样本之间的余弦相似度: A n e g ∈ R N × K A_{neg}\in\mathbb{R}^{N\times K} AnegRN×K

    将相似度矩阵 A p o s ,   A n e g A_{pos},~A_{neg} Apos, Aneg 拼接得到 A ∈ R N × ( 1 + K ) A\in\mathbb{R}^{N\times (1+K)} ARN×(1+K),计算交叉熵损失, BP \text{BP} BP 回传,更新 f q f_q fq 的参数。

  • 动量更新:
    f k = m ∗ f k + ( 1 − m ) ∗ f q f_k=m*f_k+(1-m)*f_q fk=mfk+(1m)fq

  • 更新 Memory bank : \text{Memory~bank}: Memory bank: k k k 加入队列中,队首的旧编码出队。这样每次入队的新编码都是上一次更新后编码器的输出。

  • 方法 ( a ) (a) (a) end-to-end \text{end-to-end} end-to-end:需要保存每一个样本的 representation \text{representation} representation。对于需要大量负样本的对比学习来说,这需要占用大量显存等硬件资源。

  • 方法 ( b ) (b) (b) 是基于 Memory bank \text{Memory~bank} Memory bank:需要把 encoder \text{encoder} encoder 学习得到的各个 representation \text{representation} representation 保存到 Memory bank \text{Memory~bank} Memory bank ,然后从中采样出负样本。虽然 Memory bank \text{Memory~bank} Memory bank 中的负样本不再占用显存,但是前后 encoder \text{encoder} encoder 会存在较大差异。

  • 方法 ( c ) (c) (c) MoCo \text{MoCo} MoCo:基于 momentum encoder \text{momentum~encoder} momentum encoder 得到的 representation \text{representation} representation 保存到一个 queue \text{queue} queue 中。需要强调的是, momentum encoder \text{momentum~encoder} momentum encoder 可以看作是对 query encoder \text{query~encoder} query encoder 的平滑。

  • 实验:
    m = 0.999 m=0.999 m=0.999 m = 0.9 m=0.9 m=0.9 表现要好,在 I m a g e N e t ImageNet ImageNet 数据集上的实验结果为当时的 S O T A SOTA SOTA:

  • 总结:

    • 采用 momemtum encoder \text{momemtum~encoder} momemtum encoder 对负样本进行编码,解决了 GPU \text{GPU} GPU 显存限制的问题,同时增强了负样本 encoder \text{encoder} encoder 的一致性;
    • 最新 representation \text{representation} representation 放入一个不断更新的 queue \text{queue} queue 中,进一步增强了负样本 encoder \text{encoder} encoder 的一致性。

SimCLR (ICML20) \text{SimCLR~(ICML20)} SimCLR (ICML20)

SimCLR \text{SimCLR} SimCLR Hinton \text{Hinton} Hinton 组的 Chen Ting \text{Chen~Ting} Chen Ting 20 20 20 2 2 2 月推出的, SimCLR(4x) \text{SimCLR(4x)} SimCLR(4x) ImageNet \text{ImageNet} ImageNet 上面达到 76.5 % 76.5\% 76.5% Top 1 \text{Top~1} Top 1 Accuracy \text{Accuracy} Accuracy比当时的 SOTA \text{SOTA} SOTA 模型高了 7 7 7 个点。此外,对预训练好的模型应用 1 % 1\% 1% ImageNet \text{ImageNet} ImageNet 的标签进行 Fine-tune \text{Fine-tune} Fine-tune SimCLR \text{SimCLR} SimCLR 可以达到 85.5 % 85.5\% 85.5% Top 5 \text{Top~5} Top 5 Accuracy \text{Accuracy} Accuracy,性能再涨 10 10 10 个点。

  • S i m C L R SimCLR SimCLR 框架依然是双塔结构:
  • Input \text{Input} Input 任意一张图片样本 x x x,采用不同的数据增强方式,得到 2 2 2 张样本: x i ,   x j x_i,~x_j xi, xj
探究不同数据增强组合方式,选取了最优的;
  • 随机裁剪后再 resize \text{resize} resize 成原来大小 (Random cropping followed by resize back to the original size) \text{(Random~cropping~followed~by~resize~back~to~the~original~size)} (Random cropping followed by resize back to the original size)
  • 随机色彩失真 (Random color distortions) \text{(Random~color~distortions)} (Random color distortions)
  • 随机高斯模糊 (Random Gaussian Deblur) \text{(Random~Gaussian~Deblur)} (Random Gaussian Deblur)
  • x i ,   x j x_i,~x_j xi, xj 输入到共享参数的两个 encoder \text{encoder} encoder 编码器中,得到其输出 h i ,   h j h_i,~h_j hi, hj
  • encoder \text{encoder} encoder 之后,将 h i ,   h j h_i,~h_j hi, hj 经过一个非线性映射:
    g ( h i ) = W ( 2 ) ReLU ( W ( 1 ) h i ) g(\mathbf{h}_i)=W^{(2)}\text{ReLU}(W^{(1)\mathbf{h}_i}) g(hi)=W(2)ReLU(W(1)hi)
    得到 representation :   z i = g ( h i ) ,   z j = g ( h j ) \text{representation}:~z_i=g(h_i),~z_j=g(h_j) representation: zi=g(hi), zj=g(hj)

    研究发现 encoder \text{encoder} encoder 编码后的输出 h \mathbf h h 会保留和数据增强变换相关的信息,设置非线性层就是去掉这些信息。此外,非线性层只在无监督训练时用,迁移到其他任务时不使用。

  • 计算 l o s s loss loss :使用余弦相似度 Cosine Similarity: \text{Cosine~Similarity:} Cosine Similarity: 把计算 x i   , x j x_i~,x_j xi ,xj 的相似度转化成了计算两个 representation:  z i ,   z j \text{representation:}~z_i,~z_j representation: zi, zj 的相似度:
    s i , j = z i ⊤ ⋅ z j τ ∥ z i ∥ ⋅ ∥ z j ∥ s_{i,j}=\frac{z_i^\top\cdot z_j}{\tau\|z_i\|\cdot\|z_j\|} si,j=τzizjzizj

以前都是拿右侧数据的 N − 1 N-1 N1 个作为负例, SimCLR \text{SimCLR} SimCLR 将左侧的 N − 1 N-1 N1 个也加入了进来,总计 2 ( N − 1 ) 2(N-1) 2(N1) 个负例。另外 SimCLR \text{SimCLR} SimCLR 不采用 Memory bank \text{Memory~bank} Memory bank,而是用更大的 batch size \text{batch~size} batch size。例如, batch size=8192 \text{batch~size=8192} batch size=8192 时,有 16382 16382 16382 个负例。

MoCo v2 \text{MoCo~v2} MoCo v2

对比 SimCLR \text{SimCLR} SimCLR MoCo v2 \text{MoCo~v2} MoCo v2 做了如下改动:

  • 改进了数据增强方法;
  • 模型训练过程在 encoder \text{encoder} encoder 的输出增加了相同的非线性映射;
  • 为了对比 SimCLR \text{SimCLR} SimCLR,学习率采用相同的 Cosine \text{Cosine} Cosine 衰减。

MoCo v2 \text{MoCo~v2} MoCo v2 在更小的 batch size \text{batch~size} batch size就超过了 SimCLR \text{SimCLR} SimCLR 的表现:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值