Two losses commonly used in contrastive learning are the NT-Xent (Normalized Temperature-Scaled Cross-Entropy) loss and the NT-BXent (Normalized Temperature-Scaled Binary Cross-Entropy) loss.
NT-Xent treats every non-target sample as a negative, while NT-BXent treats only samples that are neither the target nor of the same class as negatives, so each anchor can have multiple positives.
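To make the difference concrete: the NT-Xent loss (as used in SimCLR) for an anchor $z_i$ with its single positive $z_j$ is

$$\ell_i = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k \neq i} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)},$$

where $\mathrm{sim}(\cdot,\cdot)$ is cosine similarity and $\tau$ is the temperature. NT-BXent instead scores every pair independently with a binary cross-entropy. Writing $s_{ij} = \mathrm{sim}(z_i, z_j)$, $P(i)$ for the positive set of anchor $i$ (its same-class samples plus $i$ itself) and $N(i)$ for its negative set, the per-anchor loss read off from the implementation below is

$$\ell_i = -\frac{1}{|P(i)|} \sum_{j \in P(i)} \log \sigma\!\left(\frac{s_{ij}}{\tau}\right) - \frac{1}{|N(i)|} \sum_{j \in N(i)} \log\!\left(1 - \sigma\!\left(\frac{s_{ij}}{\tau}\right)\right),$$

and the batch loss is the mean of $\ell_i$ over all anchors: the two sums correspond to loss_pos and loss_neg in the code, normalized by num_pos and num_neg.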
Below is a Python (PyTorch) implementation of the NT-BXent loss that can be used directly. The code is adapted from a reference link, which explains the computation process in detail and is worth reading if you are interested.
import torch
import torch.nn as nn
import torch.nn.functional as F


class NT_BXENT_LOSS(nn.Module):
    def __init__(self):
        super(NT_BXENT_LOSS, self).__init__()

    def forward(self, x, pos_indices, temperature):
        assert len(x.size()) == 2
        # Follow the device of x instead of hard-coding .cuda(), so the loss
        # also runs on CPU-only machines.
        device = x.device
        pos_indices = pos_indices.to(device)

        # Add the indices of the principal diagonal elements (each sample
        # paired with itself) to pos_indices.
        pos_indices = torch.cat([
            pos_indices,
            torch.arange(x.size(0), device=device).reshape(x.size(0), 1).expand(-1, 2),
        ], dim=0)

        # Ground-truth labels: target[i, j] = 1 iff (i, j) is a positive pair.
        target = torch.zeros(x.size(0), x.size(0), device=device)
        target[pos_indices[:, 0], pos_indices[:, 1]] = 1.0

        # Pairwise cosine-similarity matrix of shape (N, N).
        xcs = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)
        # Set the logit of each diagonal element to "inf", signifying complete
        # correlation. sigmoid(inf) = 1.0, so this works out nicely when
        # computing the binary cross-entropy loss.
        xcs[torch.eye(x.size(0), device=device, dtype=torch.bool)] = float("inf")

        # Standard binary cross-entropy loss. We use binary_cross_entropy()
        # here and not binary_cross_entropy_with_logits() because of
        # https://github.com/pytorch/pytorch/issues/102894
        # The *_with_logits() variant uses the log-sum-exp trick, which turns
        # inf and -inf inputs into NaN results.
        loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none")

        target_pos = target.bool()
        target_neg = ~target_pos

        # Scatter the per-element losses into separate positive-pair and
        # negative-pair matrices so each can be normalized on its own.
        pos_zero = torch.zeros(x.size(0), x.size(0), device=device)
        neg_zero = torch.zeros(x.size(0), x.size(0), device=device)
        loss_pos = pos_zero.masked_scatter(target_pos, loss[target_pos])
        loss_neg = neg_zero.masked_scatter(target_neg, loss[target_neg])
        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = x.size(0) - num_pos

        # Per-sample loss: the positive term is normalized by the number of
        # positives, the negative term by the number of negatives; the final
        # loss is the mean over the batch.
        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
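As a quick sanity check, here is a minimal usage sketch. The shapes are inferred from the forward() signature: x is an (N, D) batch of embeddings, pos_indices is a (P, 2) tensor of (row, column) index pairs marking same-class samples (the diagonal self-pairs are added internally), and temperature is a scalar. The specific pair values below are made up for illustration.

import torch

x = torch.randn(8, 128, requires_grad=True)  # (N, D) embedding batch
# Hypothetical positive pairs: samples 0/2 and 1/5 share a class.
# Each pair is listed in both directions so the target matrix stays symmetric.
pos_indices = torch.tensor([[0, 2], [2, 0], [1, 5], [5, 1]])
criterion = NT_BXENT_LOSS()
loss = criterion(x, pos_indices, temperature=0.1)
loss.backward()  # gradients flow back into x (or a model producing x)
print(loss.item())

Note the design choice in the final line of forward(): positives and negatives are averaged separately per anchor, so an anchor with few positives is not drowned out by the much larger number of negative pairs.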