NT-Xent(归一化温标交叉熵)损失在 PyTorch 中解释和实现

对 NT-Xent 损失的直观解释,并逐步解释操作和我们在 PyTorch 中的实现

先来看一个公式
l i , j = − log ⁡ exp ⁡ ( sin ⁡ ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] exp ⁡ ( sin ⁡ ( z i , z k ) / τ ) \mathbb{l}_{i,j}=-\log\frac{\exp(\sin(\mathbf{z}_i,\mathbf{z}_j)/\tau)}{\sum_{k=1}^{2N}1_{[k\neq i]}\exp(\sin(\mathbf{z}_i,\mathbf{z}_k)/\tau)} li,j=logk=12N1[k=i]exp(sin(zi,zk)/τ)exp(sin(zi,zj)/τ)

NT-Xent 损失

    在较高层次上,对比学习模型的输入来自 N 个底层图像的 2N 个图像。N 个底层图像中的每一个都使用一组随机图像增强进行增强,以生成 2 N个增强图像。这就是我们最终在输入模型的单个训练批次中获得 2N 个图像的方式。
在这里插入图片描述

PyTorch 中 NT-Xent 损失的实现

    网上看到的许多NT-Xent 丢失的实现都是从头开始实现所有操作,他们中的一些人实现损失函数的效率很低,更喜欢使用for 循环而不是 GPU 并行性。相反,我们将使用不同的方法。我们将根据 PyTorch 已经提供的标准交叉熵损失来实现此损失。为此,我们需要以 cross_entropy 可以接受的格式处理预测和真实标签。下面让我们看看如何执行此操作。

预测张量:首先,我们需要创建一个 PyTorch 张量来表示对比学习模型的输出。假设我们的批量大小为 8 (一张图片进行两次变换,所以2N=8)。我们将输入变量称为“x”。然后对x进行升维操作,计算时应用了tensor的广播性质。这里的具体操作可以看我的另一篇博客:PyTorch 中所有样本对的余弦相似度快速计算

x = torch.randn(8, 2)
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)

    如上所述,我们需要忽略每个特征向量的自相似性得分,因为它对模型的学习没有贡献,并且当我们想要计算交叉熵损失时,它会成为不必要的麻烦。为此,我们将定义一个变量“eye”,它是一个矩阵,主对角线上的元素值为 1.0,其余元素值为 0.0。我们可以使用以下命令创建这样的矩阵。


eye = torch.eye(8)
eye = eye.bool()#将其转换为布尔矩阵,以便我们可以使用此掩码矩阵索引到“xcs”变量。
y = xcs.clone()#将张量“xcs”克隆到名为“y”的张量中,以便稍后可以引用“xcs”张量。
y[eye] = float("-inf")#沿所有对余弦相似度矩阵的主对角线的值设置为-inf,这样当我们计算每行的 softmax 时,该值将不会产生任何影响。e的负无穷次方为0

ground truth (target tensor):对于我们使用的示例(2N=8),真实标签的样子如下:

tensor([1,0,3,2,5,4,7,6])

    很难理解?这是因为张量“y”中的以下索引对包含正对。这里需要对F.cosine_similarity()函数做一定了解,他有两个重要参数,(input,target)一般用全连接层的输出做input,一般为一个二维数组,形状为[batch_size,class_num]。含义是第i个样本为第j类的概率。target表示对应的真实标签的下标索引,所以input[i][target[i]]表示第i个样本预测正确的概率。这里可以参考这篇文章:【pytorch】交叉熵损失函数 F.cross_entropy()

上面的target张量在计算过程中的作用可以用这张图表示:
在这里插入图片描述
    表示取图中打勾的索引元素,打勾的代表彼此互为正样本,即:

(0, 1), (1, 0)
(2, 3), (3, 2)
(4, 5), (5, 4)
(6, 7), (7, 6)

    为了创建上面的张量,我们可以使用以下 PyTorch 代码,它将ground truth标签存储在变量“target”中。

target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1

交叉熵损失:我们拥有计算损失所需的所有材料了!唯一要做的就是调用 PyTorch 中的 cross_entropy API。

loss = F.cross_entropy(xcs  / temperature, target, reduction="mean")

整合以上代码:

def nt_xent_loss(x, temperature):
  assert len(x.size()) == 2

  # Cosine similarity
  xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
  xcs[torch.eye(x.size(0)).bool()] = float("-inf")

  # Ground truth labels
  target = torch.arange(8)
  target[0::2] += 1
  target[1::2] -= 1

  # Standard cross-entropy loss
  return F.cross_entropy(xcs / temperature, target, reduction="mean")

以上文章主要翻译自:NT-Xent (Normalized Temperature-Scaled Cross-Entropy) Loss Explained and Implemented in PyTorch感兴趣的可以看看原文,绝对精彩!

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值