CLIP-ReID代码解读八——loss文件夹(triplet_loss.py)

triplet_loss.py 实现了深度学习中的三元组损失(Triplet Loss)和一些相关的辅助函数。以下是详细注释:

导入必要的库

from turtle import pd
import torch
from torch import nn

定义归一化函数

def normalize(x, axis=-1):
    """归一化到单位长度,沿指定维度进行。
    Args:
      x: PyTorch 变量
    Returns:
      x: PyTorch 变量,形状与输入相同
    """
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x

定义欧氏距离函数

def euclidean_dist(x, y):
    """
    计算两个张量之间的欧氏距离。
    Args:
      x: PyTorch 变量,形状 [m, d]
      y: PyTorch 变量,形状 [n, d]
    Returns:
      dist: PyTorch 变量,形状 [m, n]
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)  # m x 1 -> m x n
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()  # n x 1 -> n x m, 然后转置
    dist = xx + yy
    dist = dist - 2 * torch.matmul(x, y.t())
    dist = dist.clamp(min=1e-12).sqrt()  # 数值稳定性
    return dist

定义余弦距离函数

def cosine_dist(x, y):
    """
    计算两个张量之间的余弦距离。
    Args:
      x: PyTorch 变量,形状 [m, d]
      y: PyTorch 变量,形状 [n, d]
    Returns:
      dist: PyTorch 变量,形状 [m, n]
    """
    m, n = x.size(0), y.size(0)
    x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n)
    y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t()
    xy_intersection = torch.mm(x, y.t())
    dist = xy_intersection / (x_norm * y_norm)
    dist = (1. - dist) / 2
    return dist

定义难例挖掘函数

def hard_example_mining(dist_mat, labels, return_inds=False):
    """对每个锚点找到最难的正样本和负样本。
    Args:
      dist_mat: PyTorch 变量,样本之间的距离矩阵,形状 [N, N]
      labels: PyTorch 长张量,形状 [N]
      return_inds: 是否返回索引。若为 `False`,节省时间(?)
    Returns:
      dist_ap: PyTorch 变量,距离(锚点,正样本);形状 [N]
      dist_an: PyTorch 变量,距离(锚点,负样本);形状 [N]
      p_inds: PyTorch 长张量,形状 [N];
        选择的最难正样本的索引;0 <= p_inds[i] <= N - 1
      n_inds: PyTorch 长张量,形状 [N];
        选择的最难负样本的索引;0 <= n_inds[i] <= N - 1
    注意:仅考虑所有标签具有相同数量样本的情况,
      以便我们可以并行处理所有锚点。
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # 形状 [N, N]
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # `dist_ap` 表示距离(锚点,正样本)
    # `dist_ap` 和 `relative_p_inds` 的形状均为 [N, 1]
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # `dist_an` 表示距离(锚点,负样本)
    # `dist_an` 和 `relative_n_inds` 的形状均为 [N, 1]
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # 形状 [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # 形状 [N, N]
        ind = (labels.new().resize_as_(labels)
               .copy_(torch.arange(0, N).long())
               .unsqueeze(0).expand(N, N))
        # 形状 [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
        # 形状 [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an

定义三元组损失类

class TripletLoss(object):
    """
    使用更难样本挖掘的三元组损失,
    基于原始三元组损失修改。
    """

    def __init__(self, margin=None, hard_factor=0.0):
        self.margin = margin
        self.hard_factor = hard_factor
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, global_feat, labels, normalize_feature=False):
        if normalize_feature:
            global_feat = normalize(global_feat, axis=-1)
        dist_mat = euclidean_dist(global_feat, global_feat)  # 计算全局特征的距离矩阵
        dist_ap, dist_an = hard_example_mining(dist_mat, labels)  # 挖掘难样本

        dist_ap *= (1.0 + self.hard_factor)
        dist_an *= (1.0 - self.hard_factor)

        y = dist_an.new().resize_as_(dist_an).fill_(1) 
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return loss, dist_ap, dist_an

总结

  1. 归一化函数:将输入张量归一化到单位长度。
  2. 欧氏距离函数:计算两个张量之间的欧氏距离矩阵。
  3. 余弦距离函数:计算两个张量之间的余弦距离矩阵。
  4. 难例挖掘函数:找到每个样本中最难的正样本和负样本,返回它们的距离。
  5. 三元组损失类:使用更难样本挖掘的三元组损失函数,支持可选的边界(margin)和硬度因子(hard factor)。

这些函数和类共同实现了一个复杂的深度学习损失函数,用于改进特征学习的效果。

  • 12
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yiruzhao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值