为了打好基础,还是得深入理解代码啊啊啊,虽然看到代码都头疼,能咋整,还是一点一点来吧T-T,仅自留,毕竟我还是研0的小白。
先写loss部分吧,后面再慢慢写model、train、test部分~(师兄说主要看这四个部分就可啦~)
一、loss部分
首先就是一个难样本挖掘的三元组损失。
class OriTripletLoss(nn.Module):
"""Triplet loss with hard positive/negative mining.
Reference:
Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
Args:
- margin (float): margin for triplet.
"""
def __init__(self, batch_size, margin=0.3):
super(OriTripletLoss, self).__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin) # 获得一个简单的距离triplet函数
def forward(self, inputs, targets):
"""
Args:
- inputs: feature matrix with shape (batch_size, feat_dim)
- targets: ground truth labels with shape (num_classes)
"""
n = inputs.size(0) #n即batch_size
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) # 每个数平方后, sum(保持行数n不变),再扩展成nxn维
dist = dist + dist.t() #dis[i][j]代表的是第i个特征与第j个特征的平方的和
dist.addmm_(1, -2, inputs, inputs.t()) # 然后减去2倍的 第i个特征*第j个特征 从而通过完全平方式得到 (a-b)^2
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability #开方,clamp做简单数值处理(为了数值的稳定性):小于min参数的dist元素值由min值取代。
#根号下不能为0,0开根号没有问题的,但是梯