Aligned TripletLoss

class AlignedTripletLoss(nn.Module):
    def __init__(self, margin=0.3):
        super().__init__()
        # margin就是三元组损失中的边界α
        self.margin = margin
        # 计算三元组损失使用的函数
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, local_features, targets):
        """
        输入:
        1.全局特征张量inputs
        2.局部特征张量local_features
        3.真实行人IDtargets

        输出:
        1.全局特征损失global_loss
        2.局部特征损失,local_loss
        """
        # 获取批量
        n = inputs.size(0)

        # 将局部特征张量进行维度压缩
        local_features = local_features.squeeze()


        """
        计算图片之间的欧氏距离
        矩阵A,B欧氏距离等于√(A^2 + (B^T)^2 - 2A(B^T))
        """
        # 计算A^2
        distance = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        # 计算A^2 + (B^T)^2
        distance = distance + distance.t()
        # 计算A^2 + (B^T)^2 - 2A(B^T)
        #distance.addmm(1, -2, inputs, inputs.t())
        distance.addmm_(mat1 = inputs, mat2 = inputs.t(), beta = 1, alpha = -2)
        # 计算√(A^2 + (B^T)^2 - 2A(B^T))
        distance = distance.clamp(min=1e-12).sqrt()  # 该distance矩阵为对称矩阵

        # 获取正负样本对距离,使用难样本挖掘
        dist_ap, dist_an, p_inds, n_inds = hard_example_mining(distance, targets, return_inds=True)
        p_inds, n_inds = p_inds.long(), n_inds.long()
        print(p_inds)
        print(n_inds)
        # 根据难样本挖掘计算得到最小相似度正样本与最大相似度负样本索引,提取对应难样本的局部特征
        p_local_features = local_features[p_inds]
        n_local_features = local_features[n_inds]

        # 对难样本局部特征使用局部对齐最小距离算法计算样本对距离
        local_dist_ap = batch_local_dist(local_features, p_local_features)
        local_dist_an = batch_local_dist(local_features, n_local_features)

        # y指明ranking_loss前一个参数大于后一个参数
        y = torch.ones_like(dist_an)
        # 全局特征损失
        global_loss = self.ranking_loss(dist_an, dist_ap, y)
        # 局部特征损失
        local_loss = self.ranking_loss(local_dist_an, local_dist_ap, y)

        return global_loss, local_loss
if __name__ == '__main__':
    target = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8]
    target = torch.Tensor(target)
    features = torch.rand(32, 2048)
    local_features = torch.randn(32, 128, 8, 1)
    a = AlignedTripletLoss()
    g_loss, l_loss = a.forward(features, local_features, target)
    print(g_loss)
    print(l_loss)

输出结果:难样本挖掘 p_index,n_index(调用hard_example_mining()太麻了…[这个东西暂时没看懂])

tensor([ 2,  0,  0,  0,  7,  4,  5,  4,  9, 10,  9,  9, 15, 14, 13, 12, 17, 16,
        19, 18, 22, 23, 20, 21, 25, 26, 25, 25, 31, 31, 29, 29])
tensor([13, 13,  4, 22, 27, 27, 18, 16,  7,  4, 20, 17, 17,  7, 27, 10,  7, 11,
         6,  6,  7, 17, 24, 24, 22, 13, 20,  4, 23, 24,  4, 11])
tensor(0.9186)
tensor(1.1561)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值