CenterLoss | 减小类间距离

1.centerloss原理

centerloss中心损失它仅仅用来减少类内的差异,而不能有效增大类间的差异性。下图中,图(a)表示softmax loss学习到的特征描述 。图(b)表示softmax loss + center loss 学习到的特征描述,他能把同一类的样本之间的距离拉近一些,使其相似性变大,尽量的往样本中心靠拢,但可以看出他没有把不同类样本之间的样本距离拉大。

centerloss的主要思路为:让每一类特征尽可能的在输出特征空间内聚集在一起。更直白的描述就是每一类的特征在特征空间中尽可能的聚集在某一个中心点附近。正常情况下,如果我们先验的知道了所有样本的GT中心点,那这个任务就好解决了,然而事实是我们无法预先获取类中心特征空间的分布。因此我们只能从训练的过程中动态的获取类中心特征,并对整体的训练过程产生约束。需要注意的是在训练的过程中,受限于GPU的显存等问题,我们不可能直接获取所有样本的特征中心,因此整个过程是基于batch进行的,而且当网络还未收敛的情况下,网络得到的特征中心也是不正确的。基于这两点,特征中心的确定势必是一个基于batch的动态过程。

2.中心点是如何维护的

接下来就详细讲一下这个动态过程,首先提出一个问题:中心点明明是不确定的,那如何让特征去聚集在这个不确定的特征中心点呢?

这要从centerloss的更新机制说起,从下面的两组公式可以看出,center中心点的更新方向是特征值和中心点的二范数,简单来说最终通过这种更新方式会使得某一类特征值对应的中心点被更新成与所有该类样本特征值的二范数和最小的位置,而这个位置我们可以广义的理解为所以特征的中心点位置。因此整体的centerloss是在边学习边找中心点的,最终中心点的确定和整体分类任务的收敛是同步进行的。

用知乎上比较概括性的话来讲就是:
center loss的原理主要是在softmax loss的基础上,通过对训练集的每个类别在特征空间分别维护一个类中心,在训练过程,增加样本经过网络映射后在特征空间与类中心的距离约束,从而兼顾了类内聚合与类间分离。

最终通过将centerloss和softmaxloss进行加权求和,实现整体的分类任务的学习。

centerloss的计算代码:

def forward(self, output_features, y_truth):
        """
        损失计算
        :param output_features: conv层输出的特征,  [b,c,h,w]
        :param y_truth:  标签值  [b,]
        :return:
        """
        batch_size = y_truth.size(0)
        output_features = output_features.view(batch_size, -1)
        assert output_features.size(-1) == self.feat_dim
        factor = self.scale / batch_size
        # return self.lamda * factor * self.lossfunc(output_features, y_truth, self.feature_centers))

        centers_batch = self.feature_centers.index_select(0, y_truth.long())  # [b,features_dim]
        diff = output_features - centers_batch
        loss = self.lamda * 0.5 * factor * (diff.pow(2).sum())
        #########
        return loss

center的更新代码:

# 改段代码需要注意的是backward返回值需要与对应的forward的输入参数一一对应。
class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers, batch_size):
        ctx.save_for_backward(feature, label, centers, batch_size)
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers, batch_size = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # init every iteration
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = centers.new_zeros(centers.size())

        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers/counts.view(-1, 1)
        return - grad_output * diff / batch_size, None, grad_centers / batch_size, None

pytorch代码
https://www.cnblogs.com/dxscode/p/12059548.html
https://github.com/jxgu1016/MNIST_center_loss_pytorch/blob/master/CenterLoss.py

  • 1
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
Center Loss是一种用于人脸识别和特征学习的损失函数。然而,Center Loss不能直接使用的原因是它需要考虑整个训练集,并在每次迭代中平均每个类的特征,这是低效的。因此,为了解决这个问题,我们需要对Center Loss进行一些改进。 一种解决方法是使用Mini-batch Center Loss,它在每个小批量数据上计算每个类别的特征中心。具体来说,对于每个类别,在小批量数据中统计该类别的特征,并计算出该类别的特征中心。然后,通过最小化特征与其对应类别的特征中心之间的距离,来更新特征中心。这样,我们可以在每个小批量数据上更新特征中心,而不需要考虑整个训练集。 另一种解决方法是使用在线更新的Center Loss。在线更新的Center Loss只在每个样本的前向传播过程中计算特征中心,并在反向传播过程中更新特征中心。这样,我们可以在每个样本上更新特征中心,而不需要在每次迭代中平均每个类的特征。 综上所述,为了解决Center Loss不能直接使用的问题,我们可以使用Mini-batch Center Loss或在线更新的Center Loss来改进Center Loss的效率和实用性。 #### 引用[.reference_title] - *1* *2* [CenterLoss原理详解(通透)](https://blog.csdn.net/weixin_54546190/article/details/124504683)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [CenterLoss | 减小类间距离](https://blog.csdn.net/qiu931110/article/details/106108936)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yuanCruise

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值