Pytorch实现有监督对比学习损失函数
关于对比损失
无监督对比损失,通常视数据增强后的图像与原图像互为正例。而对于有监督对比损失来说,可以将同一batch中标签相同的视为正例,与它不同标签的视为负例。对比学习能够使得同类更近,不同类更远。有监督对比损失公式如下。
有监督对比损失数学公式
Pytorch实现有监督对比损失
话不多说,直接看代码。为了更好的说明有监督对比损失的整个实现过程,以下代码没有经过系统整理,从一个例子,一步一步地计算出损失。若是理解了每一步,那系统整理应该没什么问题。
import torch
import torch.nn.functional as F
T = 0.5 #温度参数T
label = torch.tensor([1,0,1,0,1])
n = label.shape[0] # batch
#假设我们的输入是5 * 3 5是batch,3是句向量
representations = torch.tensor([[1, 2, 3],[1.2, 2.2, 3.3],
[1.3, 2.3, 4.3],[1.5, 2.6, 3.9],
[5.1, 2.1, 3.4]])
#这步得到它的相似度矩阵
similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
#这步得到它的label矩阵,相同label的位置为1
mask = torch.ones_like(similarity_matrix) * (label.expand(n, n).eq(label.expand(n, n).t()))
#这步得到它的不同类的矩阵,不同类的位置为1
mask_no_sim = torch.ones_like(mask) - mask
#这步产生一个对角线全为0的,其他位置为1的矩阵
mask_dui_jiao_0 = torch.ones(n ,n) - torch.eye(n, n )
#这步给相似度矩阵求exp,并且除以温度参数T
similarity_matrix = torch.exp(similarity_matrix/T)
#这步将相似度矩阵的对角线上的值全置0,因为对比损失不需要自己与自己的相似度
similarity_matrix = similarity_matrix*mask_dui_jiao_0
#这步产生了相同类别的相似度矩阵,标签相同的位置保存它们的相似度,其他位置都是0,对角线上也为0
sim = mask*similarity_matrix
#用原先的对角线为0的相似度矩阵减去相同类别的相似度矩阵就是不同类别的相似度矩阵
no_sim = similarity_matrix - sim
#把不同类别的相似度矩阵按行求和,得到的是对比损失的分母(还差一个与分子相同的那个相似度,后面会加上)
no_sim_sum = torch.sum(no_sim , dim=1)
'''
将上面的矩阵扩展一下,再转置,加到sim(也就是相同标签的矩阵上),然后再把sim矩阵与sim_num矩阵做除法。
至于为什么这么做,就是因为对比损失的分母存在一个同类别的相似度,就是分子的数据。做了除法之后,就能得到
每个标签相同的相似度与它不同标签的相似度的值,它们在一个矩阵(loss矩阵)中。
'''
no_sim_sum_expend = no_sim_sum.repeat(n, 1).T
sim_sum = sim + no_sim_sum_expend
loss = torch.div(sim , sim_sum)
'''
由于loss矩阵中,存在0数值,那么在求-log的时候会出错。这时候,我们就将loss矩阵里面为0的地方
全部加上1,然后再去求loss矩阵的值,那么-log1 = 0 ,就是我们想要的。
'''
loss = mask_no_sim + loss + torch.eye(n, n )
#接下来就是算一个批次中的loss了
loss = -torch.log(loss) #求-log
loss = torch.sum(torch.sum(loss, dim=1) )/(2*n) #将所有数据都加起来除以2n
print(loss) #0.9821
#最后一步也可以写为---建议用这个, (len(torch.nonzero(loss)))表示一个批次中样本对个数的一半
loss = torch.sum(torch.sum(loss, dim=1)) / (len(torch.nonzero(loss)))
END
大致实现过程就是这样,如果有什么问题可以随时提出。或者有什么更好的实现方法,也欢迎共享。若你要使用该损失发文章,请引用:
“Chen, L., Wang, F., Yang, R. et al. Representation learning from noisy user-tagged data for sentiment classification. Int. J. Mach. Learn. & Cyber. (2022). https://doi.org/10.1007/s13042-022-01622-7”