近期对比学习在 NLP \text{NLP} NLP领域取得了不错的成绩,例如句嵌入方法 SimCSE [ 1 ] \text{SimCSE}^{[1]} SimCSE[1]和短文本聚类方法 SCCL [ 2 ] \text{SCCL}^{[2]} SCCL[2]。为了能更好的理解近期的进展,期望通过一系列相关的文章来循序渐进的介绍其中的技术和概念。本文就作为该系列的第一篇文章吧~
一、 SimCLR \text{SimCLR} SimCLR简介
- 原始论文:A Simple Framework for Contrastive Learning of Visual Representations
- 核心思想:
-
- 将一个样本 x x x数据增强为两个不同个样本 x ~ i \tilde{x}_i x~i和 x ~ j \tilde{x}_j x~j;
- 拉近样本 x ~ i \tilde{x}_i x~i和 x ~ j \tilde{x}_j x~j的距离,并拉远它们和其他样本的距离;
-
二、 SimCLR \text{SimCLR} SimCLR框架
SimCLR \text{SimCLR} SimCLR是一个对比学习的框架,其结构如图1所示,主要包含四个组件:
1. 数据增强模块
该模块会为一个样本随机生成两个增强样本 x ~ i \tilde{x}_i x~i和 x ~ j \tilde{x}_j x~j,这两个样本组成了一个正样本对 ( x ~ i , x ~ j ) (\tilde{x}_i,\tilde{x}_j) (x~i,x~j)。
- 论文主要是针对图像的。因此,采用的数据增强方式包括:裁剪、颜色失真、高斯模糊;
2. 编码器
编码器 f ( ⋅ ) f(\cdot) f(⋅)的作用是将增强样本转换为向量表示, h i = f ( x ~ i ) \textbf{h}_i=f(\tilde{x}_i) hi=f(x~i);
- 论文选择 ResNet \text{ResNet} ResNet作为编码器, h i = f ( x ~ i ) = ResNet ( x ~ i ) \textbf{h}_i=f(\tilde{x}_i)=\text{ResNet}(\tilde{x}_i) hi=f(x~i)=ResNet(x~i);
3. 投影头(Projection head)
投影头 g ( ⋅ ) g(\cdot) g(⋅)是一个小型神经网络,其作用是将样本的向量表示映射至可以对比的空间中(也就是适合Loss计算的表示空间);
- 论文使用单层全连接神经网络作为投影头,即 z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) z_i=g(\textbf{h}_i)=W^{(2)}\sigma(W^{(1)}\textbf{h}_i) zi=g(hi)=W(2)σ(W(1)hi), σ \sigma σ是 ReLU \text{ReLU} ReLU激活函数;
4. 对比损失函数
对比损失函数 l \mathcal{l} l,其作用是:在一个包含正样本对 ( x ~ i , x ~ j ) (\tilde{x}_i,\tilde{x}_j) (x~i,x~j)的集合 { x ~ k } \{\tilde{x}_k\} {x~k},给定样本 x ~ i \tilde{x}_i x~i,从 { x ~ k } k ≠ i \{\tilde{x}_k\}_{k\neq i} {x~k}k=i中确定出 x ~ j \tilde{x}_j x~j;
三、框架的实现
上面描述了 SimCLR \text{SimCLR} SimCLR框架,本小节则是该框架的一个具体实现。
1. 损失函数 NT-Xent \text{NT-Xent} NT-Xent
- 随机采样 N N N个样本作为 minibatch \text{minibatch} minibatch,并通过数据增强生成 2 N 2N 2N个样本。这里将正样本对以外 2 ( N − 1 ) 2(N-1) 2(N−1)个样本当做负样本;
- 向量相似度计算方式为: sim ( u , v ) = u ⊤ v / ∥ u ∥ ∥ v ∥ \text{sim}(u,v)=u^\top v/\Vert u\Vert\Vert v\Vert sim(u,v)=u⊤v/∥u∥∥v∥;
- 正样本对 ( i , j ) (i,j) (i,j)的损失函数
l
i
,
j
=
−
log
exp(sim(
z
i
,
z
j
)
/
τ
)
∑
k
=
1
2
N
1
k
≠
i
exp(sim(
z
i
,
z
k
)
/
τ
)
\mathcal{l}_{i,j} = -\text{log}\frac{\text{exp(sim(}z_i,z_j)/\tau)}{\sum_{k=1}^{2N}1_{k\neq i}\text{exp(sim(}z_i,z_k)/\tau)}
li,j=−log∑k=12N1k=iexp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
其中,
1
[
k
≠
i
]
∈
{
0
,
1
}
1_{[k\neq i]}\in\{0,1\}
1[k=i]∈{0,1}是指示函数,
τ
\tau
τ是温度(temperature)参数;
- 同一个 minibatch \text{minibatch} minibatch中所有正样本对的损失之和为最终的loss,称这个loss为 NT-Xent \text{NT-Xent} NT-Xent。
2. 完整的算法描述
输入:batch size N N N,常量 τ \tau τ,结构 f , g , T f,g,\mathcal{T} f,g,T;
for 采样的minibatch { x k } k = 1 N \{x_k\}_{k=1}^N {xk}k=1N do
for all k ∈ { 1 , … , N } k\in\{1,\dots,N\} k∈{1,…,N} do
随机选择两种数据增强函数 t ∼ T , t ′ ∼ T t\sim\mathcal{T},t'\sim\mathcal{T} t∼T,t′∼T
# 第一个数据增强
x ~ 2 k − 1 = t ( x k ) \tilde{x}_{2k-1}=t(x_k) x~2k−1=t(xk)
h 2 k − 1 = f ( x ~ 2 k − 1 ) h_{2k-1}=f(\tilde{x}_{2k-1}) h2k−1=f(x~2k−1) # 表示
z 2 k − 1 = g ( h 2 k − 1 ) z_{2k-1}=g(h_{2k-1}) z2k−1=g(h2k−1) # 投影
# 第二个数据增强
x ~ 2 k = t ′ ( x k ) \tilde{x}_{2k}=t'(x_k) x~2k=t′(xk)
h 2 k = f ( x ~ 2 k − 1 ) h_{2k}=f(\tilde{x}_{2k-1}) h2k=f(x~2k−1) # 表示
z 2 k = g ( h 2 k − 1 ) z_{2k}=g(h_{2k-1}) z2k=g(h2k−1) # 投影
end for
for all i ∈ { 1 , … , 2 N } and j ∈ { 1 , … , 2 N } i\in\{1,\dots,2N\}\text{ and } j\in\{1,\dots,2N\} i∈{1,…,2N} and j∈{1,…,2N} do
s i , j = z i z j / ( ∣ ∣ z i ∣ ∣ ∣ ∣ z j ∣ ∣ ) s_{i,j}=z_iz_j/(||z_i||||z_j||) si,j=zizj/(∣∣zi∣∣∣∣zj∣∣)
end for
定义 l ( i , j ) \mathcal{l}(i,j) l(i,j)为 l i , j = − log exp(sim( z i , z j ) / τ ) ∑ k = 1 2 N 1 k ≠ i exp(sim( z i , z k ) / τ ) \mathcal{l}_{i,j} = -\text{log}\frac{\text{exp(sim(}z_i,z_j)/\tau)}{\sum_{k=1}^{2N}1_{k\neq i}\text{exp(sim(}z_i,z_k)/\tau)} li,j=−log∑k=12N1k=iexp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
L = 1 2 N ∑ k = 1 N [ l ( 2 k − 1 , 2 k ) + l ( 2 k , 2 k − 1 ) ] \mathcal{L}=\frac{1}{2N}\sum_{k=1}^N[\mathcal{l}(2k-1,2k)+\mathcal{l}(2k,2k-1)] L=2N1∑k=1N[l(2k−1,2k)+l(2k,2k−1)]
通过最小化 L \mathcal{L} L来更新网络 f f f和 g g g
end for
return 返回编码网络 f ( ⋅ ) f(\cdot) f(⋅),并丢弃 g ( ⋅ ) g(\cdot) g(⋅)
3. 训练细节
- 为了不使用memory bank,将batch size从256增大至8192;
- 由于 SGD \text{SGD} SGD在大batch size上不稳定,使用LARS进行训练;
四、分析
-
数据增强操作的组合对于学习好的向量表示至关重要
上图是不同种数据增强方式间组合带来的影响,对角线表示单个一种数据增强方法。可以发现,对角线的颜色都比较深,也就是说单一的数据增强方式效果并不好。两两组合的数据增强方式效果更佳。
-
相较于有监督学习,数据增强对对比学习更加有效
上表时数据增强程度对有监督学习(Supervised)和对比学习(SimCLR)的影响。可以发现,数据增强对“对比学习”影响更大。
-
模型越大、对比学习效果越好
上图中红色的点是对比学习的效果,随着模型规模的增大,效果也越来越好;
-
非线性投影头能改善向量表示的质量
上图中,非线性投影头优于线性投影头,线性投影头优于不进行投影;
-
合适的温度参数能够帮助模型学习到更难的负样本
观察上表, l2 norm \text{l2 norm} l2 norm是有效的,而是适当大小的 τ \tau τ也有助于模型的表现;
-
大batch size和长的训练时间也有益于对比学习
观察上图,大的batch size和较大的epoch有助于模型的表现;