【ReID】【代码注释】难样本三元组损失TriHard loss deep-person-reid/losses.py

源码URL:
https://github.com/michuanhaohao/deep-person-reid/blob/master/losses.py

TriHard loss部分的源码注释:

class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """
    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)  # pytorch的Triplet loss 需要输入ap an margin 和 倍率y,
                                                                 # 最后算Relu(ap - y*an + margin)  ap是正样本间距,an是负样本间距

    def forward(self, inputs, targets):
        """
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (num_classes)
        """
        # (a - b)^2 = a^2 - 2ab + b^2
        n = inputs.size(0)
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)  # power是逐元素点乘,sum起来就是取模的平方, keepdim保持维度不变,不要求和成一个数
        dist = dist + dist.t()  #.t()为转置,这里是实现了元素的a^2 + b^2
        dist.addmm_(1, -2, inputs, inputs.t())  # 做的是如(a1, a2, a, b) -> a1*dist + a2*a*b
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability  # 将输入input张量每个元素的夹紧到区间 [min,max][min,max],并返回结果到一个新张量,防止他等于0,再做一个开方
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))  # dist[i][mask[i]].max()得到的tensor没有维度,需要加unsqueeze(0)得到一维向量
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)  # ap最难样本dist
        dist_an = torch.cat(dist_an)  # an最难样本dist
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)  # margin在__init__中设置过了
        return loss
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

锥栗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值