对比学习和自监督学习中的InfoNCE损失

InfoNCE(Noise Contrastive Estimation)和交叉熵损失都是两个关键的概念。它们不仅在衡量概率分布之间的差异方面发挥着重要作用,而且在深度学习的自监督学习领域扮演着重要角色。虽然它们的形式和应用环境有所不同,但是我们可以发现它们之间存在着微妙的联系。

InfoNCE Loss(Noise Contrastive Estimation Loss)是一种用于自监督学习的损失函数,通常用于学习特征表示或者表征学习。它基于信息论的思想,通过对比正样本和负样本的相似性来学习模型参数。

InfoNCE Loss 公式

InfoNCE 损失的计算公式如下:
L N C E = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( P i , p o s / τ ) ∑ k exp ⁡ ( P i , k / τ ) \mathcal{L}_{NCE} = -\frac{1}{N} \sum_{i=1}^N \log\frac{\exp(P_{i, pos}/\tau)}{\sum_k \exp(P_{i,k}/\tau)} LNCE=N1i=1Nlogkexp(Pi,k/τ)exp(Pi,pos/τ)

其中 P i , p o s P_{i,pos} Pi,pos表示第i个样本与其正样本的相似性/距离,log右侧整体为正样本的概率分布。

CLIP中用到的对比损失就是典型的InfoNCE Loss,其正样本对在相似性矩阵的对角线上。
CLIP f ig

CLIP algo

InfoNCE损失在自监督学习场景下也发挥着重要的作用。在此以多视角自监督对比学习举例。

对于一个批次的数据 X ∈ R n × c × w × h X\in \mathbb{R}^{n\times c\times w\times h} XRn×c×w×h,我们计算其两个视角的图像特征 F , F ′ F,F' F,F。随后计算相似性矩阵 S = F ⋅ F T S = F\cdot F^T S=FFT,其中 F ∈ R n × d F \in \mathbb{R}^{n\times d} FRn×d

S S S中,对于一个样本数据 x i x_i xi,其正样本为它的第二个视角数据 x i ′ x_i' xi 以及其自身。得到 S S S之后,我们需要对其进行修改,抹去对角线元素,拉近样本与其第二个视角数据的相似性。
Similarity Changes

在上图所示的相似性矩阵中,左上角表示view_1的数据与view_1的数据的相似性矩阵,右上角表示view_1的数据与view_2的数据的相似性矩阵,其他以此类推。数字0,512,511表示 S S S的维度。

然后我们从中筛选出正样本对的相似性分数向量 p o s ∈ R ( n − 1 ) × 1 pos \in \mathbb{R}^{(n-1)\times1} posR(n1)×1,负样本对的相似性分数向量 n e g = ∈ R ( n − 1 ) × ( n − 1 ) neg = \in \mathbb{R}^{(n-1) \times (n-1)} neg=∈R(n1)×(n1),构成概率分布 logits = [ p o s , n e g ] ∈ R ( n − 1 ) × n \text{logits} = [pos, neg] \in \mathbb{R}^{(n-1) \times n} logits=[pos,neg]R(n1)×n。因为正样本设置在logits开头处,我们构造标签为长度为 n 的向量 y ∈ R ( n − 1 ) y \in \mathbb{R}^{(n-1)} yR(n1),其中所有元素为0,计算logits与y的CE损失即可得到目标损失。

具体代码实现

def info_nce_logits(features, batch_size, n_views=2, temperature=1.0):
    """
    It is assumed that features are aggregated so the first of all the images is first, 
    then the second view of all images and so on
    So e.g. for args.n_views == 3, features = [x_1, x_2, ..., x_1', x_2', ...,  x_1'', x_2'', ...]
    """
    device = features.device
    labels = torch.eye(batch_size, dtype=torch.bool, device=device).repeat(n_views, n_views)
    # labels is a correspondence matrix: do the features come from the same image?

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # discard the main diagonal from both: labels and similarities matrix
    mask = ~torch.eye(labels.shape[0], dtype=torch.bool, device=device) # False on the diagonal, True elsewhere
    labels = labels[mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[mask].view(similarity_matrix.shape[0], -1)

    # select and combine multiple positives
    positives = similarity_matrix[labels].view(labels.shape[0], -1)

    # select only the negatives the negatives
    negatives = similarity_matrix[~labels].view(labels.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

    logits = logits / temperature
    return logits, labels

Reference

Original Repo of Code

InfoNCE Loss公式及源码理解

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值