有监督对比学习-Supervised Contrastive Learning
Background
交叉熵损失在深度分类的监督学习中是被应用的最广泛的损失函数。但是许多工作表明这种损失的缺点,如:缺乏对噪音标签的鲁棒性,弱边界的区分的可能性(可能指的是分类问题中的边界模糊或区分度不高的情况,原文-the possiblity of poor margins),导致泛化能力的降低。然而,在实际中,大部分提出可供选择的损失函数在大数据集上并没有比交叉熵函数工作的好,证据就是最好的结果的方法仍然在继续使用交叉熵损失函数。
在近些年,对比学习工作的复兴导致了自监督表示学习的重大进展。这些工作的通用想法如下:将anchor(原样本)和一个“正样本”(做数据增强得到的样本)的表示尽可能映射在同一向量空间,将许多负样本与anchor的向量表示拉远。因为没有标签可以利用,正样本对由anchor和它的数据增强组成,负样本由随机选择的minibatch中的样本和anchor组成。下图可以描述:
Supervised Contrastive Learning
本文介绍的有监督对比学习不同自监督学习,正负实例的选择策略也不同:正样本是从与锚点相同类别的样本中提取的而不是像自监督学习中所做的是锚点的数据增强。每个锚点使用多个正实例和多个负实例,无需进行困难负样本的选择探究。
这张图很好地描绘了监督对比学习和自监督对比学习的区别。在自监督对比学习中,通常只有一个正样本对,但是相同类的图片不被视为正样本(如上图中的狗,黑白狗被视为原anchor的负样本),在一个batch中其他图片都被视为负样本对。而监督对比学习利用了标签信息,将同类的样本也视为正样本,避免了自监督学习将同类判定为负样本的错误,提升了学习的效率。
损失函数设计
我们先来看数据集的格式。随机抽取一个batch,分为N个样本对,记为 { x k , y k } k = 1 , … , N \{x_k , y_k \}_{k=1,\ldots,N} {xk,yk}k=1,…,N , y k y_k yk是 x k x_k xk的标签,之后对这个batch进行数据增强获得2N个数据样本 { x ~ ι , y ~ ι } ι = 1 , 2 , . . . , 2 N \{\tilde{x}_\iota ,\tilde{y }_\iota \}_{\iota = 1 , 2 , . . . , 2 N} {x~ι,y~ι}ι=1,2,...,2N,其中, x ~ 2 k \tilde{x}_{2k} x~2k和 x ~ 2 k − 1 \tilde{x}_{2k −1} x~2k−1是随机增强方法得到的相应数据对,在数据增强过程中,标签 y k y_k yk信息始终不会改变。
self-Contrastive Learning Loss
我们首先来看自监督对比学习。通过自监督对比学习,我们才能更好的理解监督对比学习的优势。自监督对比学习的损失函数公式如下:
L
s
e
l
f
=
∑
i
∈
I
L
i
s
e
l
f
=
−
∑
i
∈
I
log
exp
(
z
i
⋅
z
j
(
i
)
/
τ
)
∑
a
∈
A
(
i
)
exp
(
z
i
∙
z
a
/
τ
)
(
1
)
\mathcal{L}^{self}=\sum_{i\in I}\mathcal{L}_i^{self}=-\sum_{i\in I}\log\frac{\exp\left(\boldsymbol{z}_i\cdot\boldsymbol{z}_{j(i)}/\tau\right)}{\sum_{a\in A(i)}\exp\left(\boldsymbol{z}_i\bullet\boldsymbol{z}_a/\tau\right)}\quad\quad\quad(1)
Lself=i∈I∑Liself=−i∈I∑log∑a∈A(i)exp(zi∙za/τ)exp(zi⋅zj(i)/τ)(1)
其中:
i
∈
I
≡
{
1
…
2
N
}
,
A
(
i
)
≡
I
∖
{
i
}
,
z
i
i \in I ≡ \{1\dots2N\},A(i) ≡ I\setminus\{i\},z_i
i∈I≡{1…2N},A(i)≡I∖{i},zi是样本的向量表示,
τ
\tau
τ是温度系数, 其中
j
(
i
)
j(i)
j(i)是样本
i
i
i对应的正样本,其余
2
N
−
2
2N - 2
2N−2个样本都是负样本。从损失函数的设计来看,我们可以很明显看出自监督对比学习无法处理样本带标签属于容一个类的情况,因为这时候自监督对比学习会认为他们是负样本。
Contrastive Learning Loss
我们接着来看监督对比学习的损失函数形式:
L
o
u
t
s
u
p
=
∑
i
∈
I
L
o
u
t
,
i
s
u
p
=
∑
i
∈
I
−
1
∣
P
(
i
)
∣
∑
p
∈
P
(
i
)
log
exp
(
z
i
∙
z
p
/
τ
)
∑
a
∈
A
(
i
)
exp
(
z
i
∙
z
a
/
τ
)
(
2
)
L
i
n
s
u
p
=
∑
i
∈
I
L
i
n
,
i
s
u
p
=
∑
i
∈
I
−
log
{
1
∣
P
(
i
)
∣
∑
p
∈
P
(
i
)
exp
(
z
i
∙
z
p
/
τ
)
∑
a
∈
A
(
i
)
exp
(
z
i
∙
z
a
/
τ
)
}
(
3
)
\begin{aligned}\mathcal{L}_{out}^{sup}&=\sum_{i\in I}\mathcal{L}_{out,i}^{sup}=\sum_{i\in I}\frac{-1}{|P(i)|}\sum_{p\in P(i)}\log\frac{\exp{(z_i\bullet z_p/\tau)}}{\sum_{a\in A(i)}\exp{(z_i\bullet z_a/\tau)}}&(2)\\\mathcal{L}_{in}^{sup}&=\sum_{i\in I}\mathcal{L}_{in,i}^{sup}=\sum_{i\in I}-\log\left\{\frac{1}{|P(i)|}\sum_{p\in P(i)}\frac{\exp{(z_i\bullet z_p/\tau)}}{\sum_{a\in A(i)}\exp{(z_i\bullet z_a/\tau)}}\right\}&(3)\end{aligned}
LoutsupLinsup=i∈I∑Lout,isup=i∈I∑∣P(i)∣−1p∈P(i)∑log∑a∈A(i)exp(zi∙za/τ)exp(zi∙zp/τ)=i∈I∑Lin,isup=i∈I∑−log⎩
⎨
⎧∣P(i)∣1p∈P(i)∑∑a∈A(i)exp(zi∙za/τ)exp(zi∙zp/τ)⎭
⎬
⎫(2)(3)
其中:
P
(
i
)
≡
{
p
∈
A
(
i
)
:
y
~
p
=
y
~
i
}
,
P
(
i
)
P(i)\equiv\{p\in A(i):\tilde{\boldsymbol{y}}_p=\tilde{\boldsymbol{y}}_i\},P(i)
P(i)≡{p∈A(i):y~p=y~i},P(i)是正样本的集合,表示具有同一标签的样本。
∣
P
(
i
)
∣
|P(i)|
∣P(i)∣是集合的基。其余的符号含义同上面的自监督对比学习。
上面的两个损失函数代表监督对比学习的两种形式,区别在于
1
∣
P
(
i
)
∣
\frac{1}{|P(i)|}
∣P(i)∣1项在内还是在外。
由于Jensen’s 不等式,可以证明
L
o
u
t
s
u
p
⩾
L
i
n
s
u
p
\mathcal{L}_{out}^{sup} \geqslant \mathcal{L}_{in}^{sup}
Loutsup⩾Linsup。这潜在表示了前者是后者性能的上限。
代码实现
import torch
import torch.nn as nn
class SupConLoss(nn.Module):
"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
It also supports the unsupervised contrastive loss in SimCLR"""
def __init__(self, temperature=0.07, contrast_mode='all',base_temperature=0.07):
super(SupConLoss, self).__init__()
self.temperature = temperature
def forward(self, features, labels=None, mask=None):
mask = torch.eq(labels, labels.T)
# compute logits
anchor_dot_contrast = torch.div(
torch.matmul(anchor_feature, contrast_feature.T),self.temperature)
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
# tile mask
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases
logits_mask = torch.scatter(
torch.ones_like(mask),1,
torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0)
mask = mask * logits_mask
# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# compute mean of log-likelihood over positive
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
# loss
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()
return loss
如果想要了解更多的细节,最好自己去阅读原论文。