对比学习损失函数用于在无监督或半监督的情况下学习数据表示,使得相似的数据样本在表示空间中更加接近,而不相似的样本更远离。以下是几种常见的对比学习损失函数及其详细说明:
一、对比损失(Contrastive Loss)
对比损失用于使得正样本对(相似样本对)在表示空间中接近,而负样本对(不相似样本对)远离。
1、公式
\[ L = \frac{1}{2N} \sum_{i=1}^{N} \left( y_i \cdot D_i^2 + (1 - y_i) \cdot \max(margin - D_i, 0)^2 \right) \]
其中:
\( y_i \) 是标签,1 表示正样本对,0 表示负样本对。
\( D_i \) 是样本对的欧氏距离。
\( margin \) 是一个超参数,表示负样本对之间的最小距离。
2、代码实现(PyTorch)
import torch
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
def __init__(self, margin=1.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin
def forward(self, output1, output2, label):
euclidean_distance = F.pairwise_distance(output1, output2)
loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
return loss
二、三元组损失(Triplet Loss)
三元组损失用于训练模型使得锚点样本(Anchor)和正样本(Positive)之间的距离小于锚点样本和负样本(Negative)之间的距离。
1、公式
\[ L = \sum_{i=1}^{N} \left[ \|f(x_i^a) - f(x_i^p)\|_2^2 - \|f(x_i^a) - f(x_i^n)\|_2^2 + \alpha \right]_+ \]
其中:
- \( x_i^a \) 是锚点样本。
- \( x_i^p \) 是正样本。
- \( x_i^n \) 是负样本。
- \( \alpha \) 是一个超参数,表示正负样本对之间的最小距离差。
2、代码实现(PyTorch)
import torch
import torch.nn.functional as F
class TripletLoss(nn.Module):
def __init__(self, margin=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
def forward(self, anchor, positive, negative):
pos_distance = F.pairwise_distance(anchor, positive)
neg_distance = F.pairwise_distance(anchor, negative)
loss = torch.mean(F.relu(pos_distance - neg_distance + self.margin))
return loss
三、信息论对比损失(InfoNCE Loss)
InfoNCE 损失常用于自监督学习,通过最大化正样本对之间的相似度,同时最小化正样本对和负样本对之间的相似度。
1、公式
\[ L = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(f(x_i) \cdot f(x_i^+))}{\exp(f(x_i) \cdot f(x_i^+)) + \sum_{j=1}^{K} \exp(f(x_i) \cdot f(x_j^-))} \]
其中:
- \( f(x_i) \) 是样本 \( x_i \) 的表示。
- \( x_i^+ \) 是正样本。
- \( x_j^- \) 是负样本。
- \( K \) 是负样本的数量。
2、代码实现(PyTorch)
import torch
import torch.nn.functional as F
class InfoNCELoss(nn.Module):
def __init__(self, temperature=0.1):
super(InfoNCELoss, self).__init__()
self.temperature = temperature
def forward(self, features, labels):
batch_size = features.size(0)
labels = labels.contiguous().view(-1, 1)
mask = torch.eq(labels, labels.T).float()
contrast_feature = torch.cat(torch.unbind(features, dim=0), dim=0)
anchor_feature = contrast_feature
anchor_dot_contrast = torch.div(
torch.matmul(anchor_feature, contrast_feature.T),
self.temperature)
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
mask = mask.repeat(batch_size, 1)
logits_mask = torch.scatter(
torch.ones_like(mask),
1,
torch.arange(batch_size * 2).view(-1, 1).cuda(),
0
)
mask = mask * logits_mask
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = - (self.temperature / 0.07) * mean_log_prob_pos
loss = loss.view(batch_size, 2).mean()
return loss
四、 互信息最大化损失(Mutual Information Maximization Loss)
这种损失用于最大化全局表示和局部表示之间的互信息,常用于图像或图数据。
1、公式
\[ L = - \frac{1}{N} \sum_{i=1}^{N} \left[ \log \frac{\exp(f(x_i) \cdot f(g_i))}{\sum_{j=1}^{N} \exp(f(x_i) \cdot f(g_j))} \right] \]
其中:
- \( f(x_i) \) 是样本 \( x_i \) 的局部表示。
- \( f(g_i) \) 是样本 \( x_i \) 的全局表示。
2、代码实现(PyTorch)
import torch
import torch.nn.functional as F
class MutualInformationLoss(nn.Module):
def __init__(self):
super(MutualInformationLoss, self).__init__()
def forward(self, local_features, global_features):
batch_size = local_features.size(0)
scores = torch.matmul(local_features, global_features.T)
labels = torch.arange(batch_size).cuda()
loss = F.cross_entropy(scores, labels)
return loss
这些对比学习损失函数在不同的任务和数据集上有不同的效果,可以根据具体需求进行选择和调整。