Supervised Contrastive Learning in NLP: Code Implementation and Analysis

Contrastive Learning in NLP

The earliest paper on contrastive learning that I read is SimCSE: Simple Contrastive Learning of Sentence Embeddings, which covers two contrastive learning setups: unsupervised contrastive learning and supervised contrastive learning.

The contrastive learning objective can be written as:
$\ell_i = -\log\frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_i^+)/\tau}}{\sum_{j=1}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j)/\tau}}$
where

  • $N$ is the batch size, and $\mathbf{h}_i^+$ is the representation of a sample that is identical or similar to $\mathbf{h}_i$.

  • $\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j)$ is a similarity measure between the two vectors, usually cosine similarity or the inner (dot) product.

    As an aside, cosine similarity can be computed directly with PyTorch's F.cosine_similarity, or you can easily write it by hand:

    from torch.nn import functional as F
    import torch

    batch_size = 4
    dim = 3
    a = torch.randn(batch_size, dim)
    b = torch.randn(batch_size, dim)
    # row-wise cosine similarity between a[i] and b[i] -> shape [batch_size]
    result = F.cosine_similarity(a, b)
    print(result)
    # full pairwise cosine-similarity matrix -> shape [batch_size, batch_size]
    print(torch.matmul(a, b.T) / torch.norm(a, dim=1).view(-1, 1) / torch.norm(b, dim=1).view(1, -1))
    # the same pairwise matrix via L2 normalization followed by a dot product
    print(torch.matmul(F.normalize(a), F.normalize(b).T))
    

    The output is:

    tensor([-0.0526, -0.3916,  0.0200,  0.2976])
    tensor([[-0.0526, -0.5926, -0.6900, -0.9460],
            [-0.3521, -0.3916, -0.1019, -0.4261],
            [-0.2914, -0.3798,  0.0200, -0.2108],
            [-0.9280,  0.9054,  0.8497,  0.2976]])
    tensor([[-0.0526, -0.5926, -0.6900, -0.9460],
            [-0.3521, -0.3916, -0.1019, -0.4261],
            [-0.2914, -0.3798,  0.0200, -0.2108],
            [-0.9280,  0.9054,  0.8497,  0.2976]])
    

    What we actually want are the last two results: the full pairwise matrix, which shows the similarity between different samples.

    Note: the last line uses F.normalize to normalize the vectors.

For unsupervised contrastive learning, the recipe is very simple: feed the same text through the encoder (and hence its dropout layers) twice, and treat the two resulting representations as a positive pair. This yields $\mathbf{h}_i$ and $\mathbf{h}_i^+$.
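
Below is a minimal sketch of this dropout trick, assuming a Hugging Face-style encoder whose output exposes last_hidden_state and using the [CLS] vector as the sentence representation (the function and variable names here are illustrative, not SimCSE's actual code):

import torch
from torch.nn import functional as F

def unsup_simcse_loss(encoder, input_ids, attention_mask, temperature=0.05):
    # Encode the same batch twice; with dropout active (train mode), the two
    # [CLS] representations differ slightly and form positive pairs.
    h1 = encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
    h2 = encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
    # pairwise cosine similarity scaled by the temperature: [batch, batch]
    sim = F.cosine_similarity(h1.unsqueeze(1), h2.unsqueeze(0), dim=-1) / temperature
    # for sample i the positive sits on the diagonal, so the objective above
    # reduces to a cross-entropy with labels 0..N-1
    labels = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, labels)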

For supervised contrastive learning, the SimCSE paper builds on natural language inference (NLI): the positive $x_i^+$ for a sample $x_i$ is the hypothesis labeled as entailment, while the contradiction hypothesis acts as the hard negative $x_i^-$. The objective then becomes
$\mathcal{L} = -\sum_{i=1}^{N} \log\frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_i^+)/\tau}}{\sum_{j=1}^{N}\left(e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j^+)/\tau}+e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_j^-)/\tau}\right)}$
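
A minimal sketch of this NLI-based objective, assuming h, h_pos and h_neg are the already-encoded premise, entailment and contradiction representations of shape [batch, dim] (illustrative names, not the paper's code):

import torch
from torch.nn import functional as F

def sup_simcse_loss(h, h_pos, h_neg, temperature=0.05):
    # sim(h_i, h_j^+) and sim(h_i, h_j^-) for every pair in the batch
    sim_pos = F.cosine_similarity(h.unsqueeze(1), h_pos.unsqueeze(0), dim=-1)  # [B, B]
    sim_neg = F.cosine_similarity(h.unsqueeze(1), h_neg.unsqueeze(0), dim=-1)  # [B, B]
    logits = torch.cat([sim_pos, sim_neg], dim=1) / temperature  # [B, 2B]
    # the positive for sample i is sim(h_i, h_i^+), i.e. column i of sim_pos
    labels = torch.arange(h.size(0), device=h.device)
    return F.cross_entropy(logits, labels)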

Supervised Contrastive Learning

For supervised contrastive learning, intuition says the numerator of the contrastive objective should be built from samples of the same (or a similar) class, i.e. both $\mathbf{h}_i$ and $\mathbf{h}_i^+$ carry the label $y_i$. This is exactly the idea behind the SupCon loss proposed in the NeurIPS 2020 paper Supervised Contrastive Learning [paper | code].

Recently I noticed that supervised contrastive learning has already been incorporated into NLP pre-training. Below I use the supervised contrastive loss discussed and used in Learning Implicit Sentiment in Aspect-based Sentiment Analysis with Supervised Contrastive Pre-Training [paper | code] as the running example.

Note that the formulas below may differ from how the original papers write them; unless stated otherwise, the differences are purely notational.

For a batch of data, the contrastive objective is redefined as:
$\mathcal{L}^{sup} = -\sum_{i=1}^{N} \frac{1}{|P_i|} \sum_{p \in P_i} \log\frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{l\neq i}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_l)/\tau}}$
where $P_i=\{p \mid y_p=y_i,\; p\neq i\}$ is the set of other samples in the batch sharing the label of $x_i$, $|\cdot|$ denotes set size, and $N$ is the batch size.

Here is a code implementation:

import torch
from torch import nn


class ConLoss(nn.Module):
    """
    Based on https://github.com/HobbitLong/SupContrast/blob/master/losses.py
    Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.
    """
    def __init__(self, temperature=0.07, 
                 base_temperature=0.07):
        super(ConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """
        Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Original implementation:
        https://github.com/HobbitLong/SupContrast/blob/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/losses.py#L11
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        
        Example usage in NLP: https://github.com/Tribleave/SCAPT-ABSA/blob/6f7f89a131127f262a8d1fd2774e5a96b58e7193/train/trainer/pretrain.py#L209
        normed_cls_hidden = F.normalize(cls_hidden, dim=-1)
        similar_loss = similar_criterion(normed_cls_hidden.unsqueeze(1), labels=labels)
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)


        # build the contrastive mask from `mask` or `labels`
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:  # unsupervised contrastive learning
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)
    
        anchor_count = contrast_count = features.shape[1]  # number of views per sample
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_feature = contrast_feature

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        # log( exp(similarity) / sum(exp(similarity)) ) = similarity - log(sum(exp(similarity)))
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-30)

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point. 
        # Edge case e.g.:- 
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
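
For reference, here is a minimal usage sketch with random tensors (shapes and names are illustrative): the [CLS]-style features are L2-normalized and given an explicit n_views dimension, mirroring the SCAPT usage shown in the docstring.

import torch
from torch.nn import functional as F

criterion = ConLoss(temperature=0.07)
cls_hidden = torch.randn(8, 128)                          # e.g. [CLS] vectors from an encoder
labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1])
features = F.normalize(cls_hidden, dim=-1).unsqueeze(1)   # [bsz, n_views=1, dim]
loss = criterion(features, labels=labels)
print(loss)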

In Supervised Contrastive Learning the formula above is called $\mathcal{L}^{sup}_{out}$, because the $\frac{1}{|P_i|}$ averaging over positives sits outside the $\log$; if it sits inside the $\log$ instead, the loss is called $\mathcal{L}^{sup}_{in}$. I will not write that variant out here, but the paper's comparison of the two is worth summarizing:

  • Loss value vs. performance: by Jensen's inequality (the logarithm is concave), $\mathcal{L}^{sup}_{in} \leq \mathcal{L}^{sup}_{out}$. In the paper's experiments, however, $\mathcal{L}^{sup}_{out}$ performs markedly better on ImageNet with a ResNet-50 (78.7% top-1 accuracy vs. 67.4% for $\mathcal{L}^{sup}_{in}$), so the inequality alone does not tell us which formulation is the better training objective.

  • Gradient structure: for $\mathcal{L}^{sup}_{i}$, the gradient with respect to $\mathbf{h}_i$ is
    $\frac{\partial \mathcal{L}^{sup}_{i}}{\partial \mathbf{h}_i}=\frac{1}{\tau}\left\{\sum_{p \in P_i} \mathbf{h}_p \left(P_{i;p}-X_{i;p}\right)+\sum_{n\in N_i} \mathbf{h}_n P_{i;n}\right\}$
    where $N_i=\{n \mid y_n \neq y_i,\; n \neq i\}$ is the set of samples whose label differs from that of $x_i$ (again excluding $x_i$ itself), and $P_{i;p}=\frac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{l\neq i}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_l)/\tau}}$ (and analogously for $P_{i;n}$).

    In other words, $P_i \cup N_i = \{j \in N \mid j \neq i\}$, where $N$ without a subscript stands for the batch, and the subscript $n$ marks negatives.

    The two losses differ only in $X_{i;p}$:
    $X_{i;p}= \begin{cases} \dfrac{e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{p' \in P_{i}} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_{p'})/\tau}} & \text{if } \mathcal{L}^{sup}_{i}=\mathcal{L}^{sup}_{in;i}\\ \dfrac{1}{|P_i|} & \text{if } \mathcal{L}^{sup}_{i}=\mathcal{L}^{sup}_{out;i} \end{cases}$

A special case: if $\mathbf{h}_p$ is replaced by the mean $\bar{\mathbf{h}}_p$ of all positive representation vectors, then $\mathcal{L}^{sup}_{out} = \mathcal{L}^{sup}_{in}$:

$X_{i;p}\big|_{\mathbf{h}_p=\bar{\mathbf{h}}_p}=\frac{e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_p)/\tau}}{\sum_{p' \in P_{i}} e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_{p'})/\tau}}=\frac{e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_p)/\tau}}{|P_i|\cdot e^{\mathrm{sim}(\mathbf{h}_i,\bar{\mathbf{h}}_p)/\tau}}= \frac{1}{|P_i|}$
That is, with the mean every exponential term in the denominator is identical, so the sum collapses to $|P_i|$ copies of the same value.

In the Supervised Contrastive Learning paper the authors note: "From the form of $\frac{\partial\mathcal{L}^{sup}_i}{\partial\mathbf{h}_i}$, we conclude that the stabilization due to using the mean of positives benefits training. Throughout the rest of the paper, we consider only $\mathcal{L}^{sup}_{out}$."

Decoupled Contrastive Learning

Decoupled Contrastive Learning is a method I came across by chance. The paper was eventually published at ECCV 2022 and is worth a read (it had been rejected at NeurIPS 2021, even though the reviews there looked reasonable). Later, an AAAI 2024 paper, Decoupled Contrastive Learning for Long-Tailed Recognition, adopted the idea and proposed the Decoupled Supervised Contrastive Loss (DSCL). Reading its public code alongside the paper, DSCL is essentially the original SCL with an extra weighting term; as the paper itself puts it: **The proposed DSCL is a generalization of SCL in both balanced setting and imbalanced setting. If the dataset is balanced, DSCL is the same as SCL by setting $\alpha = 1/(|P_i|+1)$.**

For easier comparison with the SCL above, I restate Equations 8 and 9 of Decoupled Contrastive Learning for Long-Tailed Recognition here, keeping the notation as consistent as possible:
$\mathcal{L}^{sup} = -\sum_{i=1}^{N} \frac{1}{|P_i|} \sum_{p \in P_i} \log\frac{w_p \cdot e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_p)/\tau}}{\sum_{l\neq i}^{N} e^{\mathrm{sim}(\mathbf{h}_i,\mathbf{h}_l)/\tau}}, \qquad w_p =(1-\alpha)\frac{|P_i|+1}{|P_i|}$
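
As a quick check of the quoted claim, substituting $\alpha = 1/(|P_i|+1)$ into $w_p$ gives

$w_p = \left(1-\frac{1}{|P_i|+1}\right)\frac{|P_i|+1}{|P_i|} = \frac{|P_i|}{|P_i|+1}\cdot\frac{|P_i|+1}{|P_i|} = 1,$

so the weighted objective indeed reduces to the SCL loss above.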

One caveat: the paper defines $P_i=\{ p\in M \mid y_p=y_i\}$, where $M$ is described as "a set of sample features that can be acquired by the memory queue (He et al. 2020)". I could not find this definition of $M$ in He et al.'s paper, so I cannot tell whether $M$ includes the sample $x_i$ itself.

Looking at the code of Decoupled Contrastive Learning for Long-Tailed Recognition, however, $x_i$ is indeed masked out:

# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size).view(-1, 1).to(device),
    0
)

For convenience, and for consistency with the rest of this post, I will assume $x_i$ is excluded.

Under that assumption, the SCL-based code above can be reused:

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


class ConLoss(nn.Module):
    """
    Based on https://github.com/HobbitLong/SupContrast/blob/master/losses.py
    Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.
    """
    def __init__(self, temperature=0.07, 
                 base_temperature=0.07,
                 decoupled_mode=False,
                 weighted_alpha=0.1
                 ):
        super(ConLoss, self).__init__()
        self.temperature = temperature
        self.decoupled_mode = decoupled_mode
        self.base_temperature = base_temperature
        self.weighted_alpha = weighted_alpha

    def forward(self, features, labels=None, mask=None):
        """
        Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        """
        device = features.device
        features = F.normalize(features,dim=-1)

        # build the contrastive mask from `mask` or `labels`
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:  # unsupervised contrastive learning
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature)  # (batch_size, batch_size)
        mask[np.arange(batch_size), np.arange(batch_size)] = 0  # zero out the diagonal (self-contrast)

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        # compute log_prob 
        logits_mask = torch.ones_like(mask, device=device) - torch.eye(batch_size, device=device)
        exp_logits = (torch.exp(logits) * logits_mask).sum(1, keepdim=True)
        try:
            assert torch.all(exp_logits > 0)
        except AssertionError:
            print('exp_logits:', exp_logits)
            print('logits:', logits)
            raise ValueError('exp_logits should be greater than 0')
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            raise ValueError("logits contains NaN or Inf values")

        # compute mean of log-likelihood over positives (Eq. 3)
        if self.decoupled_mode:
            '''
            Decoupled Contrastive Learning 
            https://github.com/SY-Xuan/DSCL/blob/72c823bacabc7e09a3656e9047403681eb0ef5c2/dscl/DSCL.py#L208
            '''
            # computation of the second part of Eq. 9
            class_weighted = torch.ones_like(mask) * (1.0-self.weighted_alpha) * mask.sum(dim=1, keepdim=True)  # numerator
            class_weighted = torch.div(class_weighted, torch.where(mask.sum(dim=1, keepdim=True)>1, mask.sum(dim=1, keepdim=True)-1, 1.0))  # denominator
            # In Eq. 9, if t = i the first part of Eq. 9 does not need to be considered, so the line below is not used:
            # class_weighted = class_weighted.scatter(1, torch.arange(batch_size).view(-1, 1).to(device), self.weighted_alpha*mask.sum(dim=1))
            logits = logits * class_weighted

        # loss
        log_prob = logits - torch.log(exp_logits + 1e-30)
        mean_log_prob_pos = torch.div((log_prob * mask).sum(1), torch.where(mask.sum(1)!=0, mask.sum(1), 1e-30))
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos.mean()
        return loss
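
Finally, a quick usage sketch (random tensors, illustrative shapes) comparing the plain SCL mode with the decoupled mode; note that this version takes features of shape [bsz, dim] and normalizes them inside forward:

import torch

feats = torch.randn(8, 128)                       # [bsz, dim]
labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1])
scl = ConLoss(decoupled_mode=False)
dscl = ConLoss(decoupled_mode=True, weighted_alpha=0.1)
print(scl(feats, labels=labels))
print(dscl(feats, labels=labels))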