SimCLR
SimCLR通过隐藏空间的对比损失最大化相同数据在不同增广下的一致性来学习表达。SimCLR框架有四个主要的组件,分别是:数据增广,encode网络,projection head网络和对比学习函数。
对于数据
x
x
x,从同一个数据增广族中抽取两个独立的数据增广算子(
t
∼
T
t \sim T
t∼T和
t
′
∼
T
{t}' \sim T
t′∼T),以获得两个相关的视图
x
^
i
\hat{x}_{i}
x^i和
x
^
j
\hat{x}_{j}
x^j,
x
^
i
\hat{x}_{i}
x^i和
x
^
j
\hat{x}_{j}
x^j是一对正样本,然后一个神经网络编码器
f
(
⋅
)
f\left( \cdot \right)
f(⋅)从增广的数据中提取特征
h
i
=
f
(
x
^
i
)
,
h
j
=
f
(
x
^
j
)
,
h_{i}=f\left( \hat{x}_{i} \right), h_{j}=f\left( \hat{x}_{j} \right),
hi=f(x^i),hj=f(x^j),。再然后一个小的神经网络project head
g
(
⋅
)
g\left( \cdot \right)
g(⋅)将特征映射到对比损失的空间。project head采用带有一个隐含层的MLP获取
z
i
=
g
(
h
i
)
=
W
(
2
)
σ
(
W
(
1
)
h
i
)
z_{i} = g\left( h_{i} \right) = W^{\left( 2 \right)} \sigma \left( W^{\left( 1 \right)} h_{i}\right)
zi=g(hi)=W(2)σ(W(1)hi)。
对于包含一对正样本 x ^ i \hat{x}_{i} x^i和 x ^ j \hat{x}_{j} x^j的集合 { x ^ k } \{ \hat{x}_{k} \} {x^k},对比预测任务目的是对于给定的 x ^ i \hat{x}_{i} x^i在 { x ^ } k ≠ i \{ \hat{x} \}_{k \neq i} {x^}k=i中识别出 x ^ j \hat{x}_{j} x^j。随机挑选 N N N个样本组成一个minibatch,这个minibatch中则有 2 N 2N 2N个数据样本,将其他 2 ( N − 1 ) 2\left( N - 1\right) 2(N−1)个扩增的样本作为这个minibatch中的负样本,设 s i m ( u , v ) = u T v / ∥ u ∥ ∥ v ∥ sim\left( u, v\right) = u^{T}v / \| u\| \| v\| sim(u,v)=uTv/∥u∥∥v∥表示 l 2 l_{2} l2正则化后你的 u u u和 v v v的点积,那么对一对正样本 ( i , j ) \left( i, j \right) (i,j),损失函数如下定义:
l i , j = − l o g e x p ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] e x p ( s i m ( z i , z k ) / τ ) l_{i,j} = - log \frac{exp\left( sim \left( z_{i}, z_{j}\right) / \tau \right)}{\sum_{k=1}^{2N} \mathbb{1}_{[ k \neq i]} exp\left( sim \left( z_{i}, z_{k}\right) / \tau \right)} li,j=−log∑k=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
最后的损失函数计算一个minibatch中的所有的正样本对,包括
(
i
,
j
)
\left( i, j \right)
(i,j)和
(
j
,
i
)
\left( j,i \right)
(j,i)。下面是SimCLR的伪代码。从伪代码中可以看出,编码器
f
(
⋅
)
f\left( \cdot \right)
f(⋅)和project head
g
(
⋅
)
g\left( \cdot \right)
g(⋅) 在训练时都会被更新参数,但是只有编码器
f
(
⋅
)
f\left( \cdot \right)
f(⋅)用于下游任务。
simCLR不采用memory bank的形式进行训练,而是加大batchsize,bacth size为8192,对于每一个正样本,将会有16382张负样本实例。增大batch size其实相当于每个minibatch时动态生成一个memory bank。论文中发现使用标准的SGD/Momentum,大batch size训练时是不稳定的,论文中采用LARS优化器。