Metric learning

一:减小类内距离,增大类间距离


class Metric_loss(nn.Module):
    def __init__(self,src_class):
        super(Metric_loss, self).__init__()

        self.n_class=src_class


    def forward(self, s_feature,s_labels):

        n, d = s_feature.shape

        # get labels


        # image number in each class
        ones = torch.ones_like(s_labels, dtype=torch.float)
        zeros = torch.zeros(self.n_class)

        zeros = zeros.cuda()

        s_n_classes = zeros.scatter_add(0, s_labels, ones)


        # image number cannot be 0, when calculating centroids
        ones = torch.ones_like(s_n_classes)
        s_n_classes = torch.max(s_n_classes, ones)


        # calculating centroids, sum and divide
        zeros = torch.zeros(self.n_class, d)

        zeros = zeros.cuda()
        s_sum_feature = zeros.scatter_add(0, torch.transpose(s_labels.repeat(d, 1), 1, 0), s_feature)

        s_centroid = torch.div(s_sum_feature, s_n_classes.view(self.n_class, 1))


        # calculating inter distance

        temp = torch.zeros((n, d)).cuda()

        for i in range(n):
            temp[i] = s_centroid[s_labels[i]]

        s_all_centroid=s_centroid.sum(axis=0)/self.n_class

        s_all_centroid=s_all_centroid.repeat(self.n_class, 1)


        # inter_loss = torch.norm(s_all_centroid- s_centroid, p=1, dim=0).max()
        #
        # intra_loss = torch.norm(temp-s_feature, p=1, dim=0).max()

        inter_loss = torch.norm(s_all_centroid- s_centroid, p=1, dim=0).sum()

        intra_loss = torch.norm(temp-s_feature, p=1, dim=0).sum()



        inter_loss = inter_loss/d
        intra_loss=intra_loss/d

        return inter_loss,intra_loss

二:三元组损失

selector = BatchHardTripletSelector()
anchor, pos, neg = selector(feature, src_label)
triplet_loss = TripletLoss(margin=1).cuda()
triplet = triplet_loss(anchor, pos, neg)

class TripletLoss(nn.Module):
    '''
    Compute normal triplet loss or soft margin triplet loss given triplets
    '''
    def __init__(self, margin = None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        if self.margin is None:  # use soft-margin
            self.Loss = nn.SoftMarginLoss()
        else:
            self.Loss = nn.TripletMarginLoss(margin = margin, p = 2)

    def forward(self, anchor, pos, neg):
        if self.margin is None:
            num_samples = anchor.shape[0]
            y = t.ones((num_samples, 1)).view(-1)
            if anchor.is_cuda: y = y.cuda()
            ap_dist = t.norm(anchor - pos, 2, dim = 1).view(-1)
            an_dist = t.norm(anchor - neg, 2, dim = 1).view(-1)
            loss = self.Loss(an_dist - ap_dist, y)
        else:
            loss = self.Loss(anchor, pos, neg)

        return loss


class BatchHardTripletSelector(object):
    '''
    a selector to generate hard batch embeddings from the embedded batch
    '''
    def __init__(self, *args, **kwargs):
        super(BatchHardTripletSelector, self).__init__()

    def __call__(self, embeds, labels):
        dist_mtx = pdist_torch(embeds, embeds).detach().cpu().numpy()# 计算距离
        labels = labels.contiguous().cpu().numpy().reshape((-1, 1))# 断开连接,深拷贝
        num = labels.shape[0]
        dia_inds = np.diag_indices(num)#返回对角线索引
        lb_eqs = labels == labels.T
        lb_eqs[dia_inds] = False
        dist_same = dist_mtx.copy()
        dist_same[lb_eqs == False] = -np.inf #负正无穷大的浮点表示
        pos_idxs = np.argmax(dist_same, axis = 1)
        dist_diff = dist_mtx.copy()
        lb_eqs[dia_inds] = True
        dist_diff[lb_eqs == True] = np.inf
        neg_idxs = np.argmin(dist_diff, axis = 1)
        pos = embeds[pos_idxs].contiguous().view(num, -1)
        neg = embeds[neg_idxs].contiguous().view(num, -1)
        return embeds, pos, neg

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值