有监督特征对比学习pytorch

https://github.com/GuillaumeErhard/Supervised_contrastive_loss_pytorch/blob/main/loss/spc.py

import torch
import torch.nn as nn

from math import log
class SupervisedContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        """
        Implementation of the loss described in the paper Supervised Contrastive Learning :
        https://arxiv.org/abs/2004.11362
        :param temperature: int
        """
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections, targets):
        """
        :param projections: torch.Tensor, shape [batch_size, projection_dim]
        :param targets: torch.Tensor, shape [batch_size]
        :return: torch.Tensor, scalar
        """
        device = torch.device("cuda") if projections.is_cuda else torch.device("cpu")

        dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
        # Minus max for numerical stability with exponential. Same done in cross entropy. Epsilon added to avoid log(0)
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )

        mask_similar_class = (targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets).to(device)
        mask_anchor_out = (1 - torch.eye(exp_dot_tempered.shape[0])).to(device)
        mask_combined = mask_similar_class * mask_anchor_out
        cardinality_per_samples = torch.sum(mask_combined, dim=1)

        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)

        return supervised_contrastive_loss


关于对比损失

  无监督对比损失,通常视数据增强后的图像与原图像互为正例。而对于有监督对比损失来说,可以将同一batch中标签相同的视为正例,与它不同标签的视为负例。对比学习能够使得同类更近,不同类更远。有监督对比损失公式如下。

有监督对比损失数学公式

Pytorch实现有监督对比损失

  话不多说,直接看代码。为了更好的说明有监督对比损失的整个实现过程,以下代码没有经过系统整理,从一个例子,一步一步地计算出损失。若是理解了每一步,那系统整理应该没什么问题。

1.通过cos计算相似度

import torch
import torch.nn.functional as F
T = 0.5  #温度参数T
label = torch.tensor([1,0,1,0,1])
n = label.shape[0]  # batch
#假设我们的输入是5 * 3  5是batch,3是句向量
representations = torch.tensor([[1, 2, 3],[1.2, 2.2, 3.3],
                                [1.3, 2.3, 4.3],[1.5, 2.6, 3.9],
                                [5.1, 2.1, 3.4]])
#这步得到它的相似度矩阵
similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
#这步得到它的label矩阵,相同label的位置为1
similarity_matrix = torch.exp(similarity_matrix/T)
print('similarity_matrix is *****')
print(similarity_matrix)

  结果

similarity_matrix is *****
tensor([[7.3891, 7.3851, 7.3241, 7.3777, 4.9964],
        [7.3851, 7.3891, 7.3172, 7.3872, 5.1341],
        [7.3241, 7.3172, 7.3891, 7.3079, 4.9291],
        [7.3777, 7.3872, 7.3079, 7.3891, 5.2278],
        [4.9964, 5.1341, 4.9291, 5.2278, 7.3891]])

2.创建各种mask

mask = torch.ones_like(similarity_matrix) * (label.expand(n, n).eq(label.expand(n, n).t())) - torch.eye(n, n )
#这步得到它的不同类的矩阵,不同类的位置为1
mask_no_sim = torch.ones_like(mask) - mask
#这步产生一个对角线全为0的,其他位置为1的矩阵
mask_dui_jiao_0 = torch.ones(n ,n) - torch.eye(n, n )
#这步给相似度矩阵求exp,并且除以温度参数T
print('mask is *****')
print(mask)

print('mask_no_sim is *****')
print(mask_no_sim)

print('mask_dui_jiao_0 is *****')
print(mask_dui_jiao_0)

结果为

mask is *****
tensor([[0., 0., 1., 0., 1.],
        [0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [1., 0., 1., 0., 0.]])
mask_no_sim is *****
tensor([[1., 1., 0., 1., 0.],
        [1., 1., 1., 0., 1.],
        [0., 1., 1., 1., 0.],
        [1., 0., 1., 1., 1.],
        [0., 1., 0., 1., 1.]])
mask_dui_jiao_0 is *****
tensor([[0., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1.],
        [1., 1., 0., 1., 1.],
        [1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 0.]])

3.相应创建各种矩阵

#这步将相似度矩阵的对角线上的值全置0,因为对比损失不需要自己与自己的相似度
similarity_matrix = similarity_matrix*mask_dui_jiao_0
print('similarity_matrix is *******')
print(similarity_matrix)

#这步产生了相同类别的相似度矩阵,标签相同的位置保存它们的相似度,其他位置都是0,对角线上也为0
sim = mask*similarity_matrix
print('sim is ')
print(sim)

#用原先的对角线为0的相似度矩阵减去相同类别的相似度矩阵就是不同类别的相似度矩阵
no_sim = similarity_matrix - sim
print('no_sim is ')
print(no_sim)
#把不同类别的相似度矩阵按行求和,得到的是对比损失的分母(还差一个与分子相同的那个相似度,后面会加上)
no_sim_sum = torch.sum(no_sim , dim=1)

结果为

similarity_matrix is *******
tensor([[0.0000, 7.3851, 7.3241, 7.3777, 4.9964],
        [7.3851, 0.0000, 7.3172, 7.3872, 5.1341],
        [7.3241, 7.3172, 0.0000, 7.3079, 4.9291],
        [7.3777, 7.3872, 7.3079, 0.0000, 5.2278],
        [4.9964, 5.1341, 4.9291, 5.2278, 0.0000]])
sim is 
tensor([[0.0000, 0.0000, 7.3241, 0.0000, 4.9964],
        [0.0000, 0.0000, 0.0000, 7.3872, 0.0000],
        [7.3241, 0.0000, 0.0000, 0.0000, 4.9291],
        [0.0000, 7.3872, 0.0000, 0.0000, 0.0000],
        [4.9964, 0.0000, 4.9291, 0.0000, 0.0000]])
no_sim is 
tensor([[0.0000, 7.3851, 0.0000, 7.3777, 0.0000],
        [7.3851, 0.0000, 7.3172, 0.0000, 5.1341],
        [0.0000, 7.3172, 0.0000, 7.3079, 0.0000],
        [7.3777, 0.0000, 7.3079, 0.0000, 5.2278],
        [0.0000, 5.1341, 0.0000, 5.2278, 0.0000]])

4.计算分母的矩阵

'''
将上面的矩阵扩展一下,再转置,加到sim(也就是相同标签的矩阵上),然后再把sim矩阵与sim_num矩阵做除法。
至于为什么这么做,就是因为对比损失的分母存在一个同类别的相似度,就是分子的数据。做了除法之后,就能得到
每个标签相同的相似度与它不同标签的相似度的值,它们在一个矩阵(loss矩阵)中。
'''
no_sim_sum_expend = no_sim_sum.repeat(n, 1).T
print('no_sim_sum_expend is ')
print(no_sim_sum_expend)
sim_sum  = sim + no_sim_sum_expend

结果为

no_sim_sum_expend is 
tensor([[14.7628, 14.7628, 14.7628, 14.7628, 14.7628],
        [19.8363, 19.8363, 19.8363, 19.8363, 19.8363],
        [14.6251, 14.6251, 14.6251, 14.6251, 14.6251],
        [19.9134, 19.9134, 19.9134, 19.9134, 19.9134],
        [10.3618, 10.3618, 10.3618, 10.3618, 10.3618]])

5.计算对比loss

loss = torch.div(sim , sim_sum)
    '''
    由于loss矩阵中,存在0数值,那么在求-log的时候会出错。这时候,我们就将loss矩阵里面为0的地方
    全部加上1,然后再去求loss矩阵的值,那么-log1 = 0 ,就是我们想要的。
    '''
    loss = mask_no_sim + loss + torch.eye(n, n )
    #接下来就是算一个批次中的loss了
    loss = -torch.log(loss)  #求-log
    #loss = torch.sum(torch.sum(loss, dim=1) )/(2*n)  #将所有数据都加起来除以2n
    #print(loss)  #0.9821
    #最后一步也可以写为---建议用这个, (len(torch.nonzero(loss)))表示一个批次中样本对个数的一半
    loss = torch.sum(torch.sum(loss, dim=1)) / (len(torch.nonzero(loss)))
    

6.完整的计算

def sup_constrive(representations, label,T):
    n = label.shape[0]
    similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
    #这步得到它的label矩阵,相同label的位置为1
    mask = torch.ones_like(similarity_matrix) * (label.expand(n, n).eq(label.expand(n, n).t())) - torch.eye(n, n)
    
    #这步得到它的不同类的矩阵,不同类的位置为1
    mask_no_sim = torch.ones_like(mask) - mask
    #这步产生一个对角线全为0的,其他位置为1的矩阵
    mask_dui_jiao_0 = torch.ones(n ,n) - torch.eye(n, n )
    #这步给相似度矩阵求exp,并且除以温度参数T
    similarity_matrix = torch.exp(similarity_matrix/T)
    #这步将相似度矩阵的对角线上的值全置0,因为对比损失不需要自己与自己的相似度
    similarity_matrix = similarity_matrix*mask_dui_jiao_0
    #这步产生了相同类别的相似度矩阵,标签相同的位置保存它们的相似度,其他位置都是0,对角线上也为0
    sim = mask*similarity_matrix
    #用原先的对角线为0的相似度矩阵减去相同类别的相似度矩阵就是不同类别的相似度矩阵
    no_sim = similarity_matrix - sim
    #把不同类别的相似度矩阵按行求和,得到的是对比损失的分母(还差一个与分子相同的那个相似度,后面会加上)
    no_sim_sum = torch.sum(no_sim , dim=1)
    '''
    将上面的矩阵扩展一下,再转置,加到sim(也就是相同标签的矩阵上),然后再把sim矩阵与sim_num矩阵做除法。
    至于为什么这么做,就是因为对比损失的分母存在一个同类别的相似度,就是分子的数据。做了除法之后,就能得到
    每个标签相同的相似度与它不同标签的相似度的值,它们在一个矩阵(loss矩阵)中。
    '''
    no_sim_sum_expend = no_sim_sum.repeat(n, 1).T
    sim_sum  = sim + no_sim_sum_expend
    loss = torch.div(sim , sim_sum)
    '''
    由于loss矩阵中,存在0数值,那么在求-log的时候会出错。这时候,我们就将loss矩阵里面为0的地方
    全部加上1,然后再去求loss矩阵的值,那么-log1 = 0 ,就是我们想要的。
    '''
    loss = mask_no_sim + loss + torch.eye(n, n )
    #接下来就是算一个批次中的loss了
    loss = -torch.log(loss)  #求-log
    #loss = torch.sum(torch.sum(loss, dim=1) )/(2*n)  #将所有数据都加起来除以2n
    #print(loss)  #0.9821
    #最后一步也可以写为---建议用这个, (len(torch.nonzero(loss)))表示一个批次中样本对个数的一半
    loss = torch.sum(torch.sum(loss, dim=1)) / (len(torch.nonzero(loss)))
    
    return loss

x = torch.rand(8,64)
label = torch.tensor([0,2,3,2,1,1,3,1])
sup_constrive(x, label,T=0.1)

大致实现过程就是这样,如果有什么问题可以随时提出。或者有什么更好的实现方法,也欢迎共享。若你要使用该损失发文章,请引用:

  “Chen, L., Wang, F., Yang, R. et al. Representation learning from noisy user-tagged data for sentiment classification. Int. J. Mach. Learn. & Cyber. (2022). https://doi.org/10.1007/s13042-022-01622-7

  • 4
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值