【笔记】Center Loss : 减小类内距离,增加类间距,使用矩阵计算高维空间距离集大成者

 

代码

首先上github代码:GitHub - KaiyangZhou/pytorch-center-loss: Pytorch implementation of Center Loss

class CenterLoss(nn.Module):
    """Center loss.
    
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
 
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
 
    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())
 
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))
 
        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
 
        return loss

 

distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()

distmat.addmm_(1, -2, x, self.centers.t())

classes = torch.arange(self.num_classes).long()
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
mask = labels.eq(classes.expand(batch_size, self.num_classes))

 

附加:

        ECCV2016的文章《A Discriminative Feature Learning Approach for Deep Face Recognition》 主要为了进一步区分人脸。

code:https://github.com/ydwen/caffe-face
无法上github的可以在这里下载:https://download.csdn.net/download/duan19920101/12178594
Center Loss的Caffe实现:https://github.com/BOBrown/SSD-Centerloss

        Center Loss是通过将特征和特征中心的距离和softmax loss一同作为损失函数,使得类内距离更小,有点L1,L2正则化的意思。最关键是在训练时要使用2个Loss函数:Softmax Loss + lamda * Center Loss:

        和metric learning(度量学习)的想法一致,希望同类样本之间紧凑,不同类样本之间分散。现有的CNN最常用的softmax损失函数来训练网络,得到的深度特征通常具有比较强的区分性,也就是比较强的类间判别力。关于softmax的类内判别力,直接看图:

        上面给的是mnist的最后一层特征在二维空间的一个分布情况,可以看到类间是可分的,但类内存在的差距还是比较大的,在某种程度上类内间距大于类间的间距。对于像人脸这种复杂分布的数据,我们通常不仅希望数据在特征空间不仅是类间可分,更重要的是类内紧凑(分类任务中类别较多时均存在这个问题)。因为同一个人的类内变化很可能会大于类间的变化,只有保持类内紧凑,我们才能对那些类内大变化的样本有一个更加鲁棒的判定结果。也就是学习一种discriminative的特征。

下图就是我们希望达到的一种效果:

        考虑保持softmax loss的类间判别力,提出center loss,center loss就是为了约束类内紧凑的条件。相比于传统的CNN,仅改变了原有的损失函数,易于训练和优化网络。

        下面公式中log函数的输入就是softmax的结果(是概率),而Ls表示的是softmax loss的结果(是损失)。wx+b是全连接层的输出,因此log的输入就表示xi属于类别yi的概率。

Center Loss

       先看看center loss的公式LC。cyi表示第yi个类别的特征中心,xi表示全连接层之前的特征。实际使用的时候,m表示mini-batch的大小。因此这个公式就是希望一个batch中的每个样本的feature离feature 的中心的距离的平方和要越小越好,也就是类内距离要越小越好。
 

 

img1 label1
img2 label2
img3 label3
...
...
...

 

参考文章:
https://blog.csdn.net/yang_502/article/details/72792786
https://blog.csdn.net/u014380165/article/details/76946339
https://blog.csdn.net/sinat_33486980/article/details/101214447
https://blog.csdn.net/u011808673/article/details/81050616

  • 4
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值