Contrastive Learning in NLP
The earliest paper on contrastive learning (for NLP) that I came across is SimCSE: Simple Contrastive Learning of Sentence Embeddings, which describes two contrastive learning setups: unsupervised contrastive learning and supervised contrastive learning.
The contrastive learning objective can be written as:
$$\ell_i = -\log \frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_i^+)/\tau}}{\sum_{j=1}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j)/\tau}}$$
where
- $N$ is the batch size, and $\mathbf{h}_i^+$ is the vector representation of a sample that is identical or similar to $\mathbf{h}_i$.
- $\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j)$ is a similarity measure between the two vectors, usually cosine similarity or the inner (dot) product.
A quick aside: cosine similarity can be computed directly with PyTorch's `F.cosine_similarity`, or you can of course write it by hand:

```python
from torch.nn import functional as F
import torch

batch_size = 4
dim = 3
a = torch.randn(batch_size, dim)
b = torch.randn(batch_size, dim)

result = F.cosine_similarity(a, b)
print(result)
print(torch.matmul(a, b.T) / torch.norm(a, dim=1).view(-1, 1) / torch.norm(b, dim=1).view(1, -1))
print(torch.matmul(F.normalize(a), F.normalize(b).T))
```
The output is:

```
tensor([-0.0526, -0.3916,  0.0200,  0.2976])
tensor([[-0.0526, -0.5926, -0.6900, -0.9460],
        [-0.3521, -0.3916, -0.1019, -0.4261],
        [-0.2914, -0.3798,  0.0200, -0.2108],
        [-0.9280,  0.9054,  0.8497,  0.2976]])
tensor([[-0.0526, -0.5926, -0.6900, -0.9460],
        [-0.3521, -0.3916, -0.1019, -0.4261],
        [-0.2914, -0.3798,  0.0200, -0.2108],
        [-0.9280,  0.9054,  0.8497,  0.2976]])
```
What we actually need are the last two results, the full pairwise matrices, which show the similarity between every pair of samples. Note: the last line uses `F.normalize` for the normalization.
For unsupervised contrastive learning, the implementation is very simple: pass the same text through the model (and hence its dropout layers) twice, and the two resulting encodings are treated as a positive pair. This directly yields $\mathbf{h}_i$ and $\mathbf{h}_i^+$; a minimal sketch is given below.
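As a rough illustration only (not SimCSE's actual training code), the sketch below uses a toy encoder in which dropout is the sole source of noise between the two forward passes; `ToyEncoder`, the dimensions, and the temperature are assumptions made up for this example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a real sentence encoder (e.g. BERT + pooling)."""
    def __init__(self, vocab_size=100, dim=32, p_drop=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(p_drop)

    def forward(self, token_ids):
        # mean-pool token embeddings; dropout makes the two passes differ
        return self.dropout(self.emb(token_ids)).mean(dim=1)

def unsup_simcse_loss(h1, h2, tau=0.05):
    """InfoNCE over in-batch negatives: h1[i] and h2[i] are two dropout views of sample i."""
    h1, h2 = F.normalize(h1, dim=-1), F.normalize(h2, dim=-1)
    sim = h1 @ h2.T / tau              # (N, N) cosine similarities
    labels = torch.arange(h1.size(0))  # positives sit on the diagonal
    return F.cross_entropy(sim, labels)

encoder = ToyEncoder()
tokens = torch.randint(0, 100, (8, 16))                     # a fake batch of 8 "sentences"
loss = unsup_simcse_loss(encoder(tokens), encoder(tokens))  # same input, two dropout masks
print(loss)
```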
For supervised contrastive learning, the original paper builds on a natural language inference (NLI) task: the positive sample $x_i^+$ for a sample $x_i$ is a hypothesis labeled as entailment in the dataset, while a hypothesis labeled as contradiction serves as the negative sample $x_i^-$. In formula form:
$$\mathcal{L} = - \sum_{i=1}^{N} \log \frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_i^+)/\tau}}{\sum_{j=1}^{N}\left(e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j^+)/\tau}+e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j^-)/\tau}\right)}$$
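A minimal sketch of this objective, assuming the three embedding matrices come from whatever sentence encoder is being trained (the function name and the random inputs below are placeholders for illustration):

```python
import torch
import torch.nn.functional as F

def sup_simcse_loss(h, h_pos, h_neg, tau=0.05):
    """Supervised SimCSE-style loss.

    h, h_pos, h_neg: (N, dim) embeddings of the premise, an entailed hypothesis
    (positive) and a contradicting hypothesis (hard negative), respectively.
    """
    h, h_pos, h_neg = (F.normalize(x, dim=-1) for x in (h, h_pos, h_neg))
    sim_pos = h @ h_pos.T / tau                    # (N, N): sim(h_i, h_j^+)
    sim_neg = h @ h_neg.T / tau                    # (N, N): sim(h_i, h_j^-)
    logits = torch.cat([sim_pos, sim_neg], dim=1)  # (N, 2N) denominator terms
    labels = torch.arange(h.size(0))               # the positive for row i is column i of sim_pos
    return F.cross_entropy(logits, labels)

# toy usage with random embeddings standing in for encoder outputs
N, dim = 8, 32
loss = sup_simcse_loss(torch.randn(N, dim), torch.randn(N, dim), torch.randn(N, dim))
print(loss)
```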
Supervised Contrastive Learning
For Supervised Contrastive Learning, the intuition is that the numerator of the contrastive objective should be built from samples that are the same or similar, i.e. $\mathbf{h}_i$ and $\mathbf{h}_i^+$ share the same label $y_i$. Based on this idea, a NeurIPS 2020 paper titled Supervised Contrastive Learning [paper | code] proposed the SupCon loss.
Recently, while reading papers, I noticed that supervised contrastive learning has already been incorporated into NLP pre-training. Here we take the Supervised Contrastive Learning discussed and used in Learning Implicit Sentiment in Aspect-based Sentiment Analysis with Supervised Contrastive Pre-Training [paper | code] as an example.
Note that the formulas here may differ from those in the original paper; unless stated otherwise, the differences are purely notational.
For a batch of data, we redefine the contrastive learning objective as:
$$\mathcal{L}^{sup} = -\sum_{i=1}^{N} \frac{1}{|P_i|} \sum_{p \in P_i} \log \frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{l\neq i}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_l)/\tau}}$$
where $P_i=\{p \mid y_p=y_i,\ p\neq i\}$ is the set of other samples in the batch that belong to the same class as sample $x_i$, $|\cdot|$ denotes the size of a set, and $N$ is the batch size.
Below is a code implementation:
```python
import torch
import torch.nn as nn


class ConLoss(nn.Module):
    """
    Based on https://github.com/HobbitLong/SupContrast/blob/master/losses.py
    Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.
    """
    def __init__(self, temperature=0.07, base_temperature=0.07):
        super(ConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """
        Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Original code:
        https://github.com/HobbitLong/SupContrast/blob/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/losses.py#L11
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Usage in NLP: https://github.com/Tribleave/SCAPT-ABSA/blob/6f7f89a131127f262a8d1fd2774e5a96b58e7193/train/trainer/pretrain.py#L209
            normed_cls_hidden = F.normalize(cls_hidden, dim=-1)
            similar_loss = similar_criterion(normed_cls_hidden.unsqueeze(1), labels=labels)
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))
        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        # build the positive-pair mask from `mask` or `labels`
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:  # unsupervised contrastive learning
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        anchor_count = contrast_count
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_feature = contrast_feature

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        # log( exp(similarity) / sum(exp(similarity)) ) = similarity - log(sum(exp(similarity)))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-30)

        # compute mean of log-likelihood over positives,
        # modified to handle the edge case when an anchor has no positive pair.
        # Edge case e.g.:
        #   features of shape: [4,1,...]
        #   labels: [0,1,1,2]
        #   loss before mean: [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss
```
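As a quick sanity check (the shapes and label values below are arbitrary), the loss expects L2-normalized features of shape [bsz, n_views, dim]:

```python
import torch
import torch.nn.functional as F

criterion = ConLoss(temperature=0.07)

bsz, n_views, dim = 8, 2, 128
features = F.normalize(torch.randn(bsz, n_views, dim), dim=-1)
labels = torch.randint(0, 3, (bsz,))

print(criterion(features, labels=labels))  # supervised (SupCon) loss
print(criterion(features))                 # unsupervised (SimCLR-style) loss
```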
In Supervised Contrastive Learning, the formula above is referred to as $\mathcal{L}^{sup}_{out}$, because the averaging over positives (the $\frac{1}{|P_i|}$ factor) sits outside the $\log$; if that factor is instead placed inside the $\log$, the loss is called $\mathcal{L}^{sup}_{in}$. I will not write that variant out here, but the paper's comparison of the two is worth summarizing:
- Loss value vs. performance: by Jensen's inequality (the logarithm is concave), $\mathcal{L}^{sup}_{out} \leq \mathcal{L}^{sup}_{in}$. In practice, however, $\mathcal{L}^{sup}_{out}$ performs significantly better on ImageNet with a ResNet-50 backbone (78.7% top-1 accuracy vs. 67.4% for $\mathcal{L}^{sup}_{in}$), so the inequality alone does not tell us which formulation is preferable.
- Gradient structure: for $\mathcal{L}^{sup}_{i}$, the gradient with respect to $\mathbf{h}_i$ is
  $$\frac{\partial \mathcal{L}^{sup}_{i}}{\partial \mathbf{h}_i}=\frac{1}{\tau}\Big\{\sum_{p \in P_i} \mathbf{h}_p \left(P_{i;p}-X_{i;p}\right)+\sum_{n\in N_i} \mathbf{h}_n P_{i;n}\Big\}$$
  where $N_i=\{n \mid y_n \neq y_i,\ n \neq i\}$ is the set of samples whose class differs from that of $x_i$ (again excluding $x_i$ itself), and $P_{i;p}=\frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{l\neq i}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_l)/\tau}}$. In other words, $P_i \cup N_i = \{j\in N \mid j \neq i\}$, where $N$ here denotes the batch (of size $N$) and $N$ with a subscript denotes the negatives. The only difference between the two variants lies in $X_{i;p}$:
  $$X_{i;p}= \begin{cases} \dfrac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{p' \in P_{i}} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_{p'})/\tau}} & \text{if } \mathcal{L}^{sup}_{i}=\mathcal{L}^{sup}_{in;i}\\[2ex] \dfrac{1}{|P_i|} & \text{if } \mathcal{L}^{sup}_{i}=\mathcal{L}^{sup}_{out;i} \end{cases}$$
One special case: if $\mathbf{h}_p$ is set to the mean of all positive representation vectors, then $\mathcal{L}^{sup}_{out} = \mathcal{L}^{sup}_{in}$:
$$X_{i;p}\Big|_{\mathbf{h}_p=\bar{\mathbf{h}}_p}=\frac{e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_p)/\tau}}{\sum_{p' \in P_{i}} e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_{p'})/\tau}}=\frac{e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_p)/\tau}}{|P_i|\cdot e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_{p})/\tau}}= \frac{1}{|P_i|}$$
That is, with the mean, every exponential in the denominator takes the same value, so the sum collapses to $|P_i|$ copies of that value.
In the Supervised Contrastive Learning paper the authors note: "From the form of $\frac{\partial \mathcal{L}^{sup}_i}{\partial \mathbf{h}_i}$, we conclude that the stabilization due to using the mean of positives benefits training. Throughout the rest of the paper, we consider only $\mathcal{L}^{sup}_{out}$."
Decoupled Contrastive Learning
I stumbled upon Decoupled Contrastive Learning by chance. The paper was eventually published at ECCV 2022; interested readers can look at the original (it had been rejected from NeurIPS 2021, even though the other reviewers' scores seemed reasonable). Later, an AAAI 2024 paper titled Decoupled Contrastive Learning for Long-Tailed Recognition adopted Decoupled Contrastive Learning and proposed the Decoupled Supervised Contrastive Loss (DSCL). Reading its public code together with the paper's description, DSCL simply adds a weighting term on top of the original SCL; as the paper itself puts it: **The proposed DSCL is a generalization of SCL in both balanced setting and imbalanced setting. If the dataset is balanced, DSCL is the same as SCL by setting $\alpha = 1/(|P_i|+1)$**.
To make the comparison with the SCL described above easier, I restate Equations 8 and 9 of Decoupled Contrastive Learning for Long-Tailed Recognition here, keeping the notation as consistent as possible:
$$\mathcal{L}^{sup} = -\sum_{i=1}^{N} \frac{1}{|P_i|} \sum_{p \in P_i} \log \frac{w_p \cdot e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{l\neq i}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_l)/\tau}}, \qquad w_p =(1-\alpha)\,\frac{|P_i|+1}{|P_i|}$$
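As a sanity check on the balanced-setting claim above, substituting $\alpha = 1/(|P_i|+1)$ into $w_p$ gives
$$w_p = \Big(1-\frac{1}{|P_i|+1}\Big)\frac{|P_i|+1}{|P_i|} = \frac{|P_i|}{|P_i|+1}\cdot\frac{|P_i|+1}{|P_i|} = 1,$$
so the weighted loss reduces to the SCL objective from the previous section.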
One caveat: the original paper defines $P_i=\{ p\in M \mid y_p=y_i\}$, where $M$ is described as follows: "We use M to denote a set of sample features that can be acquired by the memory queue (He et al. 2020)." I could not find this $M$ in He et al.'s paper, so it is unclear whether $M$ includes the sample $x_i$ itself. The code of Decoupled Contrastive Learning for Long-Tailed Recognition does, however, remove $x_i$:

```python
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size).view(-1, 1).to(device),
    0
)
```

For convenience, I will therefore assume $x_i$ is excluded, consistent with the rest of this post.
With that assumption, the SCL-based code above can be reused, as follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConLoss(nn.Module):
    """
    Based on https://github.com/HobbitLong/SupContrast/blob/master/losses.py
    Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.
    """
    def __init__(self, temperature=0.07,
                 base_temperature=0.07,
                 decoupled_mode=False,
                 weighted_alpha=0.1):
        super(ConLoss, self).__init__()
        self.temperature = temperature
        self.decoupled_mode = decoupled_mode
        self.base_temperature = base_temperature
        self.weighted_alpha = weighted_alpha

    def forward(self, features, labels=None, mask=None):
        """
        Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        """
        device = features.device
        features = F.normalize(features, dim=-1)

        # build the positive-pair mask from `mask` or `labels`
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:  # unsupervised contrastive learning
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature)  # (batch_size, batch_size)
        idx = torch.arange(batch_size, device=device)
        mask[idx, idx] = 0  # zero the diagonal of the (batch_size, batch_size) mask
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # compute log_prob
        logits_mask = torch.ones_like(mask, device=device) - torch.eye(batch_size, device=device)
        exp_logits = (torch.exp(logits) * logits_mask).sum(1, keepdim=True)
        try:
            assert torch.all(exp_logits > 0)
        except AssertionError:
            print('exp_logits:', exp_logits)
            print('logits:', logits)
            raise ValueError('exp_logits should be greater than 0')
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise ValueError("logits contains NaN or Inf values")

        # compute mean of log-likelihood over positives (Eq. 3)
        if self.decoupled_mode:
            '''
            Decoupled Contrastive Learning
            https://github.com/SY-Xuan/DSCL/blob/72c823bacabc7e09a3656e9047403681eb0ef5c2/dscl/DSCL.py#L208
            '''
            # computation of the second part of Eq. 9
            class_weighted = torch.ones_like(mask) * (1.0 - self.weighted_alpha) * mask.sum(dim=1, keepdim=True)  # numerator
            class_weighted = torch.div(class_weighted, torch.where(mask.sum(dim=1, keepdim=True) > 1, mask.sum(dim=1, keepdim=True) - 1, 1.0))  # denominator
            # In Eq. 9, when t = i the first part does not need to be computed,
            # so the following line is unused:
            # class_weighted = class_weighted.scatter(1, torch.arange(batch_size).view(-1, 1).to(device), self.weighted_alpha * mask.sum(dim=1))
            logits = logits * class_weighted

        # loss
        log_prob = logits - torch.log(exp_logits + 1e-30)
        mean_log_prob_pos = torch.div((log_prob * mask).sum(1), torch.where(mask.sum(1) != 0, mask.sum(1), 1e-30))
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos.mean()
        return loss
```
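A quick illustrative call (shapes and label values are arbitrary); note that this version expects features of shape [bsz, dim] and normalizes them inside `forward`:

```python
import torch

criterion = ConLoss(temperature=0.07, decoupled_mode=True, weighted_alpha=0.1)

bsz, dim = 8, 128
features = torch.randn(bsz, dim)           # raw encoder outputs; normalized in forward()
labels = torch.randint(0, 3, (bsz,))

print(criterion(features, labels=labels))  # decoupled supervised contrastive loss
```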