源码URL:
https://github.com/michuanhaohao/deep-person-reid/blob/master/losses.py
TriHard loss 部分（`TripletLoss` 类）的源码（附注释）:
class TripletLoss(nn.Module):
    """Triplet loss with batch-hard positive/negative mining (TriHard loss).

    For every anchor in the batch, the hardest positive (same label, largest
    Euclidean distance, the anchor itself included) and the hardest negative
    (different label, smallest Euclidean distance) are selected. The hinge
    ``max(0, d(a, p) - d(a, n) + margin)`` is then averaged over all anchors.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification.
        arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """

    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        # MarginRankingLoss(x1, x2, y) = mean(max(0, -y * (x1 - x2) + margin)).
        # forward() calls it with x1 = d(a, n), x2 = d(a, p), y = 1, which
        # yields mean(max(0, d(a, p) - d(a, n) + margin)).
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """Compute the batch-hard triplet loss.

        Args:
            inputs: feature matrix with shape (batch_size, feat_dim).
            targets: ground-truth labels with shape (batch_size,).

        Returns:
            Scalar tensor: the ranking hinge loss averaged over all anchors.
            An anchor with no negative in the batch (every label identical)
            contributes 0 instead of raising.
        """
        n = inputs.size(0)

        # Pairwise Euclidean distances via ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
        # sq_norm[i, j] == ||x_i||^2 for every column j.
        sq_norm = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = sq_norm + sq_norm.t()  # ||x_i||^2 + ||x_j||^2 (a fresh tensor)
        # dist <- 1 * dist + (-2) * inputs @ inputs^T. The keyword form is used
        # because the positional addmm_(beta, alpha, mat1, mat2) overload is
        # deprecated.
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        # Rounding can leave zero or slightly negative entries, and sqrt has an
        # infinite gradient at 0, so clamp before taking the square root.
        dist = dist.clamp(min=1e-12).sqrt()

        # mask[i, j] is True when samples i and j share a label (diagonal included,
        # so every row has at least one positive).
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())

        # Hardest positive: largest distance among same-label entries.
        dist_ap, _ = torch.max(dist.masked_fill(~mask, float("-inf")), dim=1)
        # Hardest negative: smallest distance among different-label entries.
        # A row with no negatives yields +inf, whose hinge term is 0.
        dist_an, _ = torch.min(dist.masked_fill(mask, float("inf")), dim=1)

        # Ranking hinge: y = 1 asks for dist_an to exceed dist_ap by the margin.
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)