详解常用的对比学习损失

对比学习损失函数用于在无监督或半监督的情况下学习数据表示,使得相似的数据样本在表示空间中更加接近,而不相似的样本更远离。以下是几种常见的对比学习损失函数及其详细说明:

一、对比损失(Contrastive Loss)

对比损失用于使得正样本对(相似样本对)在表示空间中接近,而负样本对(不相似样本对)远离。

1、公式

\[ L = \frac{1}{2N} \sum_{i=1}^{N} \left( y_i \cdot D_i^2 + (1 - y_i) \cdot \max(margin - D_i, 0)^2 \right) \]

其中:
\( y_i \) 是标签,1 表示正样本对,0 表示负样本对。
\( D_i \) 是样本对的欧氏距离。
\( margin \) 是一个超参数,表示负样本对之间的最小距离。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                          label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss

二、三元组损失(Triplet Loss)

三元组损失用于训练模型使得锚点样本(Anchor)和正样本(Positive)之间的距离小于锚点样本和负样本(Negative)之间的距离。

1、公式

\[ L = \sum_{i=1}^{N} \left[ \|f(x_i^a) - f(x_i^p)\|_2^2 - \|f(x_i^a) - f(x_i^n)\|_2^2 + \alpha \right]_+ \]

其中:
- \( x_i^a \) 是锚点样本。
- \( x_i^p \) 是正样本。
- \( x_i^n \) 是负样本。
- \( \alpha \) 是一个超参数,表示正负样本对之间的最小距离差。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_distance = F.pairwise_distance(anchor, positive)
        neg_distance = F.pairwise_distance(anchor, negative)
        loss = torch.mean(F.relu(pos_distance - neg_distance + self.margin))
        return loss

三、信息论对比损失(InfoNCE Loss)

InfoNCE 损失常用于自监督学习,通过最大化正样本对之间的相似度,同时最小化正样本对和负样本对之间的相似度。

1、公式

\[ L = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(f(x_i) \cdot f(x_i^+))}{\exp(f(x_i) \cdot f(x_i^+)) + \sum_{j=1}^{K} \exp(f(x_i) \cdot f(x_j^-))} \]

其中:
- \( f(x_i) \) 是样本 \( x_i \) 的表示。
- \( x_i^+ \) 是正样本。
- \( x_j^- \) 是负样本。
- \( K \) 是负样本的数量。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        batch_size = features.size(0)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float()

        contrast_feature = torch.cat(torch.unbind(features, dim=0), dim=0)
        anchor_feature = contrast_feature

        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        mask = mask.repeat(batch_size, 1)
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * 2).view(-1, 1).cuda(),
            0
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = - (self.temperature / 0.07) * mean_log_prob_pos
        loss = loss.view(batch_size, 2).mean()

        return loss

四、 互信息最大化损失(Mutual Information Maximization Loss)

这种损失用于最大化全局表示和局部表示之间的互信息,常用于图像或图数据。

1、公式

\[ L = - \frac{1}{N} \sum_{i=1}^{N} \left[ \log \frac{\exp(f(x_i) \cdot f(g_i))}{\sum_{j=1}^{N} \exp(f(x_i) \cdot f(g_j))} \right] \]

其中:
- \( f(x_i) \) 是样本 \( x_i \) 的局部表示。
- \( f(g_i) \) 是样本 \( x_i \) 的全局表示。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class MutualInformationLoss(nn.Module):
    def __init__(self):
        super(MutualInformationLoss, self).__init__()

    def forward(self, local_features, global_features):
        batch_size = local_features.size(0)
        scores = torch.matmul(local_features, global_features.T)

        labels = torch.arange(batch_size).cuda()
        loss = F.cross_entropy(scores, labels)

        return loss

这些对比学习损失函数在不同的任务和数据集上有不同的效果,可以根据具体需求进行选择和调整。

虹膜识别孪生网络对比损失函数是一种用于训练孪生网络的损失函数,用于学习将同一主体的不同图像映射到相似的特征空间中,而将不同主体的图像映射到不同的特征空间中。该损失函数的目标是最小化同一主体图像对之间的距离,并最大化不同主体图像对之间的距离。 引用中提到了配对的对比损失作为唯一的监督信号,这是一种常见的用于训练孪生网络的对比损失函数。该损失函数通过比较同一主体的图像对和不同主体的图像对之间的距离来进行训练。具体而言,对于每个图像对,损失函数会计算它们在特征空间中的欧氏距离,并根据它们的标签(同一主体或不同主体)来调整损失。通过最小化同一主体图像对之间的距离和最大化不同主体图像对之间的距离,孪生网络可以学习到更具判别性的特征表示。 以下是一个示例代码,演示了如何使用虹膜识别孪生网络对比损失函数进行训练: ```python import tensorflow as tf # 定义孪生网络结构 def siamese_network(input_shape): input = tf.keras.Input(shape=input_shape) # 网络结构定义... return model # 定义对比损失函数 def contrastive_loss(y_true, y_pred): margin = 1.0 loss = tf.reduce_mean(y_true * tf.square(y_pred) + (1 - y_true) * tf.square(tf.maximum(margin - y_pred, 0))) return loss # 加载数据集 train_data = ... train_labels = ... # 创建孪生网络模型 input_shape = (64, 64, 3) model = siamese_network(input_shape) # 编译模型 model.compile(optimizer='adam', loss=contrastive_loss) # 训练模型 model.fit(train_data, train_labels, epochs=10, batch_size=32) # 使用训练好的模型进行预测 test_data = ... predictions = model.predict(test_data) # 相关问题:
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值