对比学习损失函数中超参数temperature的作用

背景

最近在看凯明大神的对比学习MOCO时,看到infoNCE loss的公式时,对其中参数T(应该是tao,打不出来,就浅用T代替一下)有点费解,于是查阅了一些资料,记录一下自己的理解。

首先,附上infoNCE loss的公式。

L q = − l o g e x p ( q ⋅ k + / τ ) ∑ i = 0 k e x p ( q ⋅ k i / τ ) L_q = -log\frac{exp(q\cdot k_+ / \tau )}{\sum_{i=0}^{k}exp(q\cdot k_i / \tau)} Lq=logi=0kexp(qki/τ)exp(qk+/τ)
在其他地方看到这个公式的另一种写法,感觉更容易理解,也更加一般。

L ( x i ) = − l o g e x p ( s i , i / τ ) ∑ k ≠ i e x p ( s i , k / τ ) + e x p ( s i , i / τ ) L(x_i) = -log\frac{exp(s_{i,i}/ \tau )}{\sum_{k\neq i}exp(s_{i,k} / \tau) + exp(s_{i,i} / \tau)} L(xi)=logk=iexp(si,k/τ)+exp(si,i/τ)exp(si,i/τ)
简单的来说,这个公式就是熟悉的cross entropy loss(交叉熵损失)的一个变体。 s i , i s_{i,i} si,i是指当前特征与正样本间的相似度,这个相似度可以用点乘,也可以用其他方式计算。 s i , k s_{i,k} si,k是指当前特征与负样本之间的相似度。

其次,在对比学习中,每张图片相当于一个类别,对于每张图片,通过对自身数据增强后的图片为正例,其余所有图片都是负例。对比学习的目的是尽量使正例之间的相似度相近,且负例之间的相似度越低越好。换一句话说,就是要训练一个特征提取网络,使得所有图片在特征空间中的特征向量都尽可能的分开。下面是在InstDisc文中一副对比学习图片。

在这里插入图片描述

感谢以下大佬的博文和论文,本文是在他们的基础上写作的
https://blog.csdn.net/qq_36560894/article/details/114874268
(CVPR2021)理解对比损失的性质以及温度系数的作用:arxiv
对上面论文的理解:知乎

超参数temperature的直观理解

由于infoNCE loss是交叉熵损失的一个变体,为了更加直观的理解,我们先在交叉熵损失中加入temperature,看一下有什么样的效果。

假设一个三分类的问题,预测图片是猫,狗还是猪。特征提取到最后一层,输出为[1 ,2, 3],假设预测正确,结果确实是猪,那么交叉熵应当这样计算。

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
x = torch.Tensor([[1, 2, 3]])
y = torch.Tensor([2]).type(torch.long)

# 当temperature=1时
t = 1
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[0.0900, 0.2447, 0.6652]])
# loss:tensor(0.4076)

# 当temperature=0.5时
t = 0.5
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[0.0159, 0.1173, 0.8668]])
# loss:tensor(0.1429)

# 当temperature=0.1时
t = 0.1
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[2.0611e-09, 4.5398e-05, 9.9995e-01]])
# loss:tensor(4.5418e-05)

不难看出,当分类结果正确时,当temperature越小时,softmax输出各类别的分数差别越大,loss越小。

可以在尝试当分类结果错误,比如当正确分类结果为0时,即y=torch.Tensor([0]).type(torch.long)时,有着这样的规律:当分类结果错误时,当temperature越小时,softmax输出各类别的分数差别越大,loss越大。

对比学习中的temperature参数理解

讲完了对比学习的背景和超参数T的直观理解,我们有如下结论:

(1)对比学习的目的是训练一个特征提取网络,使得所有特征向量在特征空间中尽可能的远离。

(2)当分类结果错误时,当temperature越小时,softmax输出各类别的分数差别越大,loss越大。

我们下面开始讲在对比学习loss中加入temperature参数解决的核心问题:困难负样本问题。

困难负样本,就是一张图像经过特征提取网络后,发现自己相较于自身数据增强后的图片特征,更相似于其他图片提取出的特征。但是,相似度并没有差很多。这样的样本我们就叫他困难负样本。

如果没有引入temperature参数,当有困难负样本过来时,loss相对较小,对参数的惩罚也就较小。由于我们希望所有特征向量尽量远离,因此,必须对所有错误分类的样本都加大惩罚,所以,要加入一个小于1的temperature参数,来放大对于困难负样本的惩罚。

讲到这,对比学习中的temperature参数其实就已经讲的差不多了,下面再略微提一下Uniformity-Tolerance Dilemma,也就是均匀性-容忍性困境。

这里又要说到对比学习的目标,是通过大规模自监督学习去训练一个能够很好提取特征的特征提取网络,说到底,就是一个代理任务,这个使所有图片特征尽量分开的任务本身是没有任何意义的,只是用来去训练特征提取网络。训练好这个特征提取网络后,就可以加上不同的检测头来执行一系列的下游任务,如检测、分割等。

说回temperature参数,考虑一下出现困难负样本的原因,有可能是因为两张图片确实非常相似,通常是两张图片有着相同的前景,让算法产生了混淆。也就是说,其实网络已经学到了一定的语义特征,这对下游任务是有帮助的,强行将两张非常相似图片提取出的特征相互远离,有可能打破这种语义信息,导致在执行下游任务时,效果不升反降。

因此,调temperature参数是一个很讲究的事情,太高不能很好的训练特征提取网络,太低又会打破模型学到的语义信息,损害下游任务的准确度。
这种语义信息,导致在执行下游任务时,效果不升反降。

因此,调temperature参数是一个很讲究的事情,太高不能很好的训练特征提取网络,太低又会打破模型学到的语义信息,损害下游任务的准确度。

### 对比学习损失函数在聚类中的应用 对于聚类任务而言,对比学习通过构建正样本对和负样本对来优化模型表示能力。具体到聚类场景下,对比学习的目标是在嵌入空间中拉近属于同一簇的数据点(即正样本),而推远来自不同簇的数据点(即负样本)。为了实现这一目标,通常采用的损失函数形式可以概括为: #### InfoNCE Loss 一种广泛使用的对比损失函数InfoNCE (Noise Contrastive Estimation),其定义如下: \[ \mathcal{L}_{\text {infoNCE }}=\sum_{i=1}^{N}-\log \frac{\exp (\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_j\right) / \tau)}{\sum_{k=1}^{K} I(k \neq j) \exp (\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_k\right) / \tau)} \] 其中 \( z_i \) 和 \( z_j \) 是同一个实例经过不同变换得到的两个视图对应的表征向量;\( sim() \) 表示相似度度量方式,比如余弦相似度;\( τ \) 则是一个温度参数用来调整分布锐利程度[^1]。 然而,在实际操作过程中发现仅依靠上述标准对比损失难以获得理想的聚类效果,因为这需要大量精心挑选的负样本来维持良好的互信息边界。针对这个问题,研究者们提出了多种改进方案,例如引入额外的正则化项以增强特征表达力或减少对抗噪声干扰等。 #### 正则化损失 Lreg 考虑到直接最大化簇间距离可能带来过拟合风险,有工作建议加入一个专门设计的正则化损失 \( L_{\mathrm{reg}} \),旨在扩大不同基底间的高维特征差异的同时保持一定平滑性。该损失基于成对样本之间余弦相似度 s_ij 的计算,并试图最小化跨类别样本间的这种相似度得分,从而促进更清晰可分的集群结构形成[^5]。 ```python import torch from torch.nn.functional import normalize def info_nce_loss(z_i, z_j, temperature=0.5): """Compute the InfoNCE loss.""" batch_size = z_i.shape[0] # Normalize representations to unit vectors z_i_norm = normalize(z_i) z_j_norm = normalize(z_j) # Compute pairwise cosine similarities between all pairs of samples logits_ii = torch.mm(z_i_norm, z_i_norm.t()) / temperature logits_ij = torch.mm(z_i_norm, z_j_norm.t()) / temperature mask = torch.eye(batch_size).to(logits_ii.device) positives = logits_ij.diag() negatives = torch.cat([logits_ii[mask==0], logits_ij[mask==0]], dim=-1) nominator = torch.exp(positives) denominator = nominator + torch.sum(torch.exp(negatives), dim=-1) return -torch.mean(torch.log(nominator/denominator)) ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值