对比学习损失函数中超参数temperature的作用

背景

最近在看凯明大神的对比学习MOCO时,看到infoNCE loss的公式时,对其中参数T(应该是tao,打不出来,就浅用T代替一下)有点费解,于是查阅了一些资料,记录一下自己的理解。

首先,附上infoNCE loss的公式。

L q = − l o g e x p ( q ⋅ k + / τ ) ∑ i = 0 k e x p ( q ⋅ k i / τ ) L_q = -log\frac{exp(q\cdot k_+ / \tau )}{\sum_{i=0}^{k}exp(q\cdot k_i / \tau)} Lq=logi=0kexp(qki/τ)exp(qk+/τ)
在其他地方看到这个公式的另一种写法,感觉更容易理解,也更加一般。

L ( x i ) = − l o g e x p ( s i , i / τ ) ∑ k ≠ i e x p ( s i , k / τ ) + e x p ( s i , i / τ ) L(x_i) = -log\frac{exp(s_{i,i}/ \tau )}{\sum_{k\neq i}exp(s_{i,k} / \tau) + exp(s_{i,i} / \tau)} L(xi)=logk=iexp(si,k/τ)+exp(si,i/τ)exp(si,i/τ)
简单的来说,这个公式就是熟悉的cross entropy loss(交叉熵损失)的一个变体。 s i , i s_{i,i} si,i是指当前特征与正样本间的相似度,这个相似度可以用点乘,也可以用其他方式计算。 s i , k s_{i,k} si,k是指当前特征与负样本之间的相似度。

其次,在对比学习中,每张图片相当于一个类别,对于每张图片,通过对自身数据增强后的图片为正例,其余所有图片都是负例。对比学习的目的是尽量使正例之间的相似度相近,且负例之间的相似度越低越好。换一句话说,就是要训练一个特征提取网络,使得所有图片在特征空间中的特征向量都尽可能的分开。下面是在InstDisc文中一副对比学习图片。

在这里插入图片描述

感谢以下大佬的博文和论文,本文是在他们的基础上写作的
https://blog.csdn.net/qq_36560894/article/details/114874268
(CVPR2021)理解对比损失的性质以及温度系数的作用:arxiv
对上面论文的理解:知乎

超参数temperature的直观理解

由于infoNCE loss是交叉熵损失的一个变体,为了更加直观的理解,我们先在交叉熵损失中加入temperature,看一下有什么样的效果。

假设一个三分类的问题,预测图片是猫,狗还是猪。特征提取到最后一层,输出为[1 ,2, 3],假设预测正确,结果确实是猪,那么交叉熵应当这样计算。

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
x = torch.Tensor([[1, 2, 3]])
y = torch.Tensor([2]).type(torch.long)

# 当temperature=1时
t = 1
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[0.0900, 0.2447, 0.6652]])
# loss:tensor(0.4076)

# 当temperature=0.5时
t = 0.5
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[0.0159, 0.1173, 0.8668]])
# loss:tensor(0.1429)

# 当temperature=0.1时
t = 0.1
out = F.softmax(x/t, dim = 1)
print("after softmax:"+str(out))
loss = criterion(x, y)
print("loss:"+str(loss))

# 输出为
# after softmax:tensor([[2.0611e-09, 4.5398e-05, 9.9995e-01]])
# loss:tensor(4.5418e-05)

不难看出,当分类结果正确时,当temperature越小时,softmax输出各类别的分数差别越大,loss越小。

可以在尝试当分类结果错误,比如当正确分类结果为0时,即y=torch.Tensor([0]).type(torch.long)时,有着这样的规律:当分类结果错误时,当temperature越小时,softmax输出各类别的分数差别越大,loss越大。

对比学习中的temperature参数理解

讲完了对比学习的背景和超参数T的直观理解,我们有如下结论:

(1)对比学习的目的是训练一个特征提取网络,使得所有特征向量在特征空间中尽可能的远离。

(2)当分类结果错误时,当temperature越小时,softmax输出各类别的分数差别越大,loss越大。

我们下面开始讲在对比学习loss中加入temperature参数解决的核心问题:困难负样本问题。

困难负样本,就是一张图像经过特征提取网络后,发现自己相较于自身数据增强后的图片特征,更相似于其他图片提取出的特征。但是,相似度并没有差很多。这样的样本我们就叫他困难负样本。

如果没有引入temperature参数,当有困难负样本过来时,loss相对较小,对参数的惩罚也就较小。由于我们希望所有特征向量尽量远离,因此,必须对所有错误分类的样本都加大惩罚,所以,要加入一个小于1的temperature参数,来放大对于困难负样本的惩罚。

讲到这,对比学习中的temperature参数其实就已经讲的差不多了,下面再略微提一下Uniformity-Tolerance Dilemma,也就是均匀性-容忍性困境。

这里又要说到对比学习的目标,是通过大规模自监督学习去训练一个能够很好提取特征的特征提取网络,说到底,就是一个代理任务,这个使所有图片特征尽量分开的任务本身是没有任何意义的,只是用来去训练特征提取网络。训练好这个特征提取网络后,就可以加上不同的检测头来执行一系列的下游任务,如检测、分割等。

说回temperature参数,考虑一下出现困难负样本的原因,有可能是因为两张图片确实非常相似,通常是两张图片有着相同的前景,让算法产生了混淆。也就是说,其实网络已经学到了一定的语义特征,这对下游任务是有帮助的,强行将两张非常相似图片提取出的特征相互远离,有可能打破这种语义信息,导致在执行下游任务时,效果不升反降。

因此,调temperature参数是一个很讲究的事情,太高不能很好的训练特征提取网络,太低又会打破模型学到的语义信息,损害下游任务的准确度。
这种语义信息,导致在执行下游任务时,效果不升反降。

因此,调temperature参数是一个很讲究的事情,太高不能很好的训练特征提取网络,太低又会打破模型学到的语义信息,损害下游任务的准确度。

  • 5
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
自然语言处理(Natural Language Processing,NLP)领域对比学习损失函数是一种用于训练模型的损失函数,它主要用于学习将不同样本进行比较和分类的能力。对比学习损失函数的目标是通过最大化正样本之间的相似性,并最小化负样本之间的相似性来训练模型。 在NLP领域,常用的对比学习损失函数有以下几种: 1. 余弦相似度损失(Cosine Similarity Loss):该损失函数通过计算正样本和负样本之间的余弦相似度来衡量它们之间的相似性。常用的余弦相似度损失函数包括三元组损失(Triplet Loss)和N元组损失(N-Tuple Loss)。 2. 对比损失(Contrastive Loss):该损失函数通过最小化正样本和负样本之间的欧氏距离或曼哈顿距离来衡量它们之间的差异。对比损失函数常用于学习将两个样本映射到低维空间,并使得同类样本之间的距离尽可能小,异类样本之间的距离尽可能大。 3. 三元组损失(Triplet Loss):该损失函数通过最小化正样本和负样本之间的距离差异来衡量它们之间的相似性。三元组损失函数常用于学习将一个样本与其正样本和负样本进行比较,并使得正样本与该样本之间的距离小于负样本与该样本之间的距离。 4. 交叉熵损失(Cross-Entropy Loss):该损失函数常用于分类任务,在对比学习可以用于衡量正样本和负样本之间的差异。交叉熵损失函数通过计算模型预测结果与真实标签之间的差异来衡量模型的性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值