对比学习损失函数 - InfoNCE

InfoNCE Loss :构建高效对比学习模型

引言

对比学习中的InfoNCE损失函数是自监督学习领域的重要进展,它通过最大化正样本对之间的相似度并最小化负样本对的相似度,有效地引导模型学习到数据的本质特征。InfoNCE不仅提高了表示学习的质量,还为下游任务如分类、聚类等提供了强大的基础。

在这里插入图片描述


一、背景

InfoNCE(Information Noise-Contrastive Estimation)损失是对比学习中非常重要的一个损失函数,特别是在SimCLR、MoCo等框架中被广泛应用。它通过最大化正样本对之间的相似度,同时最小化负样本对之间的相似度来学习有用的表示。下面我们详细解释InfoNCE损失的公式及其各个组成部分。
在这里插入图片描述

二、公式

1. 定义

L InfoNCE = − log ⁡ exp ⁡ ( sim ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] exp ⁡ ( sim ( z i , z k ) / τ ) \mathcal{L}_{\text{InfoNCE}} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)} LInfoNCE=logk=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)

2. 公式分解

为了帮助你更深入地理解InfoNCE损失公式,我们将逐步分解并解释每个部分的意义。我们通过一个直观的例子来说明这个过程,确保每个步骤都清晰明了。

2.1 分子:正样本对的相似度得分

exp ⁡ ( sim ( z i , z j ) / τ ) \exp(\text{sim}(z_i, z_j) / \tau) exp(sim(zi,zj)/τ)

  • z i z_i zi z j z_j zj 是来自同一数据点的两个不同增强视图(augmented views)。
  • sim ( z i , z j ) \text{sim}(z_i, z_j) sim(zi,zj) 衡量这两个嵌入向量之间的相似度,通常使用余弦相似度或点积。
  • τ \tau τ 是温度参数,控制相似度分布的锐度。
  • exp ⁡ ( ⋅ ) \exp(\cdot) exp() 是指数函数,用于放大相似度得分。

示例
假设 sim ( z 1 , z 3 ) = 0.9 \text{sim}(z_1, z_3) = 0.9 sim(z1,z3)=0.9,温度参数 τ = 0.1 \tau = 0.1 τ=0.1,那么:

exp ⁡ ( sim ( z 1 , z 3 ) / 0.1 ) = exp ⁡ ( 0.9 / 0.1 ) = exp ⁡ ( 9 ) ≈ 8103.08 \exp(\text{sim}(z_1, z_3) / 0.1) = \exp(0.9 / 0.1) = \exp(9) \approx 8103.08 exp(sim(z1,z3)/0.1)=exp(0.9/0.1)=exp(9)8103.08

这表示正样本对 ( z 1 , z 3 ) (z_1, z_3) (z1,z3) 的相似度得分为 8103.08。

2.2 分母:所有负样本对的相似度得分之和

∑ k = 1 2 N 1 [ k ≠ i ] exp ⁡ ( sim ( z i , z k ) / τ ) \sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau) k=12N1[k=i]exp(sim(zi,zk)/τ)

  • 2 N 2N 2N 是批次中所有嵌入向量的数量(每个数据点有两个增强视图,因此总共有 2 N 2N 2N 个嵌入向量)。
  • 1 [ k ≠ i ] \mathbf{1}_{[k \neq i]} 1[k=i] 是指示函数,确保不将自身作为负样本。当 k = i k = i k=i 时,指示函数取值为0;否则为1。
  • exp ⁡ ( sim ( z i , z k ) / τ ) \exp(\text{sim}(z_i, z_k) / \tau) exp(sim(zi,zk)/τ) 是每个负样本对的相似度得分,同样经过温度参数调整后的指数形式。

示例
假设我们有4个嵌入向量 z 1 , z 2 , z 3 , z 4 z_1, z_2, z_3, z_4 z1,z2,z3,z4,并且 z 1 z_1 z1 z 3 z_3 z3 来自同一个数据点,其他都是负样本。我们需要计算 z 1 z_1 z1 对所有其他嵌入向量的相似度得分之和:

∑ k = 1 4 1 [ k ≠ 1 ] exp ⁡ ( sim ( z 1 , z k ) / 0.1 ) = exp ⁡ ( sim ( z 1 , z 2 ) / 0.1 ) + exp ⁡ ( sim ( z 1 , z 3 ) / 0.1 ) + exp ⁡ ( sim ( z 1 , z 4 ) / 0.1 ) \sum_{k=1}^{4} \mathbf{1}_{[k \neq 1]} \exp(\text{sim}(z_1, z_k) / 0.1) = \exp(\text{sim}(z_1, z_2) / 0.1) + \exp(\text{sim}(z_1, z_3) / 0.1) + \exp(\text{sim}(z_1, z_4) / 0.1) k=141[k=1]exp(sim(z1,zk)/0.1)=exp(sim(z1,z2)/0.1)+exp(sim(z1,z3)/0.1)+exp(sim(z1,z4)/0.1)

具体数值为:

exp ⁡

### 对比学习损失函数在聚类中的应用 对于聚类任务而言,对比学习通过构建正样本对和负样本对来优化模型表示能力。具体到聚类场景下,对比学习的目标是在嵌入空间中拉近属于同一簇的数据点(即正样本),而推远来自不同簇的数据点(即负样本)。为了实现这一目标,通常采用的损失函数形式可以概括为: #### InfoNCE Loss 一种广泛使用的对比损失函数InfoNCE (Noise Contrastive Estimation),其定义如下: \[ \mathcal{L}_{\text {infoNCE }}=\sum_{i=1}^{N}-\log \frac{\exp (\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_j\right) / \tau)}{\sum_{k=1}^{K} I(k \neq j) \exp (\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_k\right) / \tau)} \] 其中 \( z_i \) 和 \( z_j \) 是同一个实例经过不同变换得到的两个视图对应的表征向量;\( sim() \) 表示相似度度量方式,比如余弦相似度;\( τ \) 则是一个温度参数用来调整分布锐利程度[^1]。 然而,在实际操作过程中发现仅依靠上述标准对比损失难以获得理想的聚类效果,因为这需要大量精心挑选的负样本来维持良好的互信息边界。针对这个问题,研究者们提出了多种改进方案,例如引入额外的正则化项以增强特征表达力或减少对抗噪声干扰等。 #### 正则化损失 Lreg 考虑到直接最大化簇间距离可能带来过拟合风险,有工作建议加入一个专门设计的正则化损失 \( L_{\mathrm{reg}} \),旨在扩大不同基底间的高维特征差异的同时保持一定平滑性。该损失基于成对样本之间余弦相似度 s_ij 的计算,并试图最小化跨类别样本间的这种相似度得分,从而促进更清晰可分的集群结构形成[^5]。 ```python import torch from torch.nn.functional import normalize def info_nce_loss(z_i, z_j, temperature=0.5): """Compute the InfoNCE loss.""" batch_size = z_i.shape[0] # Normalize representations to unit vectors z_i_norm = normalize(z_i) z_j_norm = normalize(z_j) # Compute pairwise cosine similarities between all pairs of samples logits_ii = torch.mm(z_i_norm, z_i_norm.t()) / temperature logits_ij = torch.mm(z_i_norm, z_j_norm.t()) / temperature mask = torch.eye(batch_size).to(logits_ii.device) positives = logits_ij.diag() negatives = torch.cat([logits_ii[mask==0], logits_ij[mask==0]], dim=-1) nominator = torch.exp(positives) denominator = nominator + torch.sum(torch.exp(negatives), dim=-1) return -torch.mean(torch.log(nominator/denominator)) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进一步有进一步的欢喜

您的鼓励将是我创作的最大动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值