对 NT-Xent 损失的直观解释,并逐步解释操作和我们在 PyTorch 中的实现
先来看一个公式
l
i
,
j
=
−
log
exp
(
sin
(
z
i
,
z
j
)
/
τ
)
∑
k
=
1
2
N
1
[
k
≠
i
]
exp
(
sin
(
z
i
,
z
k
)
/
τ
)
\mathbb{l}_{i,j}=-\log\frac{\exp(\sin(\mathbf{z}_i,\mathbf{z}_j)/\tau)}{\sum_{k=1}^{2N}1_{[k\neq i]}\exp(\sin(\mathbf{z}_i,\mathbf{z}_k)/\tau)}
li,j=−log∑k=12N1[k=i]exp(sin(zi,zk)/τ)exp(sin(zi,zj)/τ)
NT-Xent 损失
在较高层次上,对比学习模型的输入来自 N 个底层图像的 2N 个图像。N 个底层图像中的每一个都使用一组随机图像增强进行增强,以生成 2 N个增强图像。这就是我们最终在输入模型的单个训练批次中获得 2N 个图像的方式。
PyTorch 中 NT-Xent 损失的实现
网上看到的许多NT-Xent 丢失的实现都是从头开始实现所有操作,他们中的一些人实现损失函数的效率很低,更喜欢使用for 循环而不是 GPU 并行性。相反,我们将使用不同的方法。我们将根据 PyTorch 已经提供的标准交叉熵损失来实现此损失。为此,我们需要以 cross_entropy 可以接受的格式处理预测和真实标签。下面让我们看看如何执行此操作。
预测张量:首先,我们需要创建一个 PyTorch 张量来表示对比学习模型的输出。假设我们的批量大小为 8 (一张图片进行两次变换,所以2N=8)。我们将输入变量称为“x”。然后对x进行升维操作,计算时应用了tensor的广播性质。这里的具体操作可以看我的另一篇博客:PyTorch 中所有样本对的余弦相似度快速计算
x = torch.randn(8, 2)
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
如上所述,我们需要忽略每个特征向量的自相似性得分,因为它对模型的学习没有贡献,并且当我们想要计算交叉熵损失时,它会成为不必要的麻烦。为此,我们将定义一个变量“eye”,它是一个矩阵,主对角线上的元素值为 1.0,其余元素值为 0.0。我们可以使用以下命令创建这样的矩阵。
eye = torch.eye(8)
eye = eye.bool()#将其转换为布尔矩阵,以便我们可以使用此掩码矩阵索引到“xcs”变量。
y = xcs.clone()#将张量“xcs”克隆到名为“y”的张量中,以便稍后可以引用“xcs”张量。
y[eye] = float("-inf")#沿所有对余弦相似度矩阵的主对角线的值设置为-inf,这样当我们计算每行的 softmax 时,该值将不会产生任何影响。e的负无穷次方为0
ground truth (target tensor):对于我们使用的示例(2N=8),真实标签的样子如下:
tensor([1,0,3,2,5,4,7,6])
很难理解?这是因为张量“y”中的以下索引对包含正对。这里需要对F.cosine_similarity()函数做一定了解,他有两个重要参数,(input,target)一般用全连接层的输出做input,一般为一个二维数组,形状为[batch_size,class_num]。含义是第i个样本为第j类的概率。target表示对应的真实标签的下标索引,所以input[i][target[i]]
表示第i个样本预测正确的概率。这里可以参考这篇文章:【pytorch】交叉熵损失函数 F.cross_entropy()
上面的target张量在计算过程中的作用可以用这张图表示:
表示取图中打勾的索引元素,打勾的代表彼此互为正样本,即:
(0, 1), (1, 0)
(2, 3), (3, 2)
(4, 5), (5, 4)
(6, 7), (7, 6)
为了创建上面的张量,我们可以使用以下 PyTorch 代码,它将ground truth标签存储在变量“target”中。
target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1
交叉熵损失:我们拥有计算损失所需的所有材料了!唯一要做的就是调用 PyTorch 中的 cross_entropy API。
loss = F.cross_entropy(xcs / temperature, target, reduction="mean")
整合以上代码:
def nt_xent_loss(x, temperature):
assert len(x.size()) == 2
# Cosine similarity
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
xcs[torch.eye(x.size(0)).bool()] = float("-inf")
# Ground truth labels
target = torch.arange(8)
target[0::2] += 1
target[1::2] -= 1
# Standard cross-entropy loss
return F.cross_entropy(xcs / temperature, target, reduction="mean")
以上文章主要翻译自:NT-Xent (Normalized Temperature-Scaled Cross-Entropy) Loss Explained and Implemented in PyTorch感兴趣的可以看看原文,绝对精彩!