Deep Learning Notes (1): Center Loss

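Copyright notice: This is an original article by the blogger, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Original link: https://blog.csdn.net/Lucifer_zzq/article/details/81236174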

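When building a loss in PyTorch, the commonly used built-in losses include the familiar MSE, cross entropy (LogSoftmax + NLLLoss), KL-divergence loss, BCE, hinge loss, and so on. For details, see: https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#loss-functions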
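The statement that cross entropy is LogSoftmax followed by NLLLoss can be checked directly. A minimal sketch on random toy logits, using the standard torch.nn modules CrossEntropyLoss, LogSoftmax and NLLLoss (all with their default mean reduction):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)            # toy data: 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])   # ground-truth class indices

# Cross entropy computed in one step ...
ce = nn.CrossEntropyLoss()(logits, target)
# ... and as LogSoftmax followed by the negative log-likelihood loss
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)

print(ce.item(), nll.item())  # the two values agree up to floating-point error
```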


This post focuses on center loss, a loss that takes intra-class distance into account:

1. Introduction:

Center loss comes from an ECCV 2016 paper: A Discriminative Feature Learning Approach for Deep Face Recognition.
Paper link: http://ydwen.github.io/papers/WenECCV16.pdf

 

2. Why use center loss:

In most of the available CNNs, the softmax loss function is used as the supervision signal to train the deep model. In order to enhance the discriminative power of the deeply learned features, this paper proposes a new supervision signal, called center loss.

The center loss simultaneously learns a center for deep features of each class and penalizes the distances between the deep features and their corresponding class centers.

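Simply put, when doing classification (whether at the image, instance, or pixel level), we want the learned features to be not only separable but also discriminative, which means the loss has to impose additional constraints.

Using softmax alone as the supervision signal only makes the features separable, not discriminative, as shown in the figure below: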

 

3. Making the learned features more discriminative: center loss

Specifically, we learn a center (a vector with the same dimension as a feature) for deep features of each class.

The CNNs are trained under the joint supervision of the softmax loss and center loss, with a hyper parameter to balance the two supervision signals. 

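Combining softmax loss and center loss:

Softmax loss keeps the features of different classes apart (separable), while center loss minimizes the intra-class feature distance by pulling each feature toward its class center.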
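For reference, the two losses and their combination as defined in the paper, where x_i ∈ R^d is the deep feature of the i-th sample, y_i its label, W_j and b_j the weight column and bias of the last fully connected layer for class j, c_{y_i} the center of class y_i, m the mini-batch size and n the number of classes:

```latex
\begin{aligned}
L_S &= -\sum_{i=1}^{m} \log \frac{e^{W_{y_i}^{T} x_i + b_{y_i}}}{\sum_{j=1}^{n} e^{W_{j}^{T} x_i + b_j}} \\
L_C &= \frac{1}{2} \sum_{i=1}^{m} \left\lVert x_i - c_{y_i} \right\rVert_2^2 \\
L   &= L_S + \lambda L_C
\end{aligned}
```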

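Here m is the mini-batch size and n is the number of classes. The L_C formula has a practical difficulty: c_{y_i} ∈ R^d (d is the feature dimension) is the center of class y_i, the class of sample i, and this center itself depends on the features being learned.

Ideally, c_{y_i} should be updated in real time as the learned features change, i.e., in every iteration the center of each class would be recomputed from the features of the entire dataset.

This is clearly impractical, so the paper makes two modifications:

1. Update the centers based on each mini-batch instead of the entire training set.

2. To avoid large perturbations caused by a few mislabeled samples, use a scalar α to control the learning rate of the centers.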

The gradient of L_C and the update rule for the centers are therefore:
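The formulas as given in the paper, consistent with the definition of L_C: δ(condition) is 1 if the condition holds and 0 otherwise, and α is the center learning rate:

```latex
\begin{aligned}
\frac{\partial L_C}{\partial x_i} &= x_i - c_{y_i} \\
\Delta c_j &= \frac{\sum_{i=1}^{m} \delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{m} \delta(y_i = j)} \\
c_j^{t+1} &= c_j^{t} - \alpha\, \Delta c_j^{t}
\end{aligned}
```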

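That is, δ(y_i = j) is 1 exactly when sample i in the mini-batch belongs to class j, the class whose center is being updated. The differences (c_j − x_i) of those samples are summed and divided by (the number of class-j samples in the batch + 1). The +1 keeps the denominator nonzero when class j does not appear in the batch, in which case Δc_j = 0 and the center stays where it is.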
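A small worked example of the Δc_j formula, as a plain NumPy sketch with made-up features and centers (the function name center_delta and the value α = 0.5 are hypothetical choices for illustration). Class 0 has two samples, so its denominator is 3; class 2 is absent from the batch, so its Δc_j is zero:

```python
import numpy as np


def center_delta(features, labels, centers):
    """Delta c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + count_j), one row per class."""
    delta = np.zeros_like(centers)
    counts = np.ones(len(centers))  # the "+1" in the denominator
    for x, y in zip(features, labels):
        delta[y] += centers[y] - x
        counts[y] += 1
    return delta / counts[:, None]


centers = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
features = np.array([[1.0, 0.0], [3.0, 0.0], [1.0, 2.0]])
labels = np.array([0, 0, 1])  # class 2 does not appear in this mini-batch

delta = center_delta(features, labels, centers)
# class 0: ((0-1) + (0-3), 0) / 3 = (-4/3, 0); class 1: (0, -1) / 2; class 2: (0, 0)

alpha = 0.5
new_centers = centers - alpha * delta  # c_j^{t+1} = c_j^t - alpha * delta c_j^t
print(new_centers)
```

Center 0 moves from (0, 0) to (2/3, 0), part of the way toward its class mean (2, 0); α controls how far each step goes.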

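The final joint loss is L = L_S + λL_C, where λ balances the softmax loss against the center loss: the larger λ is, the more discriminative the learned features become (λ = 0 reduces to softmax alone), as illustrated in the figure below: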
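To make the joint objective concrete, a minimal NumPy sketch that evaluates L_S, L_C and L = L_S + λL_C on a toy mini-batch; the weights, features, centers and λ values are made up, and the bias is omitted. Note that this only evaluates the objective: the effect of λ on how compact the features become shows up only after training with it.

```python
import numpy as np


def softmax_loss(logits, labels):
    """L_S: cross entropy summed over the mini-batch."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].sum()


def center_loss(features, labels, centers):
    """L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2."""
    return 0.5 * ((features - centers[labels]) ** 2).sum()


features = np.array([[1.0, 0.0], [3.0, 0.0], [1.0, 2.0]])  # x_i, d = 2
labels = np.array([0, 0, 1])
centers = np.array([[0.0, 0.0], [1.0, 1.0]])                # c_j, n = 2 classes
W = np.array([[1.0, -1.0], [-1.0, 1.0]])                    # toy d x n weights, column j = W_j
logits = features @ W                                        # W_j^T x_i

l_s = softmax_loss(logits, labels)
l_c = center_loss(features, labels, centers)                 # (1 + 9 + 1) / 2 = 5.5
joint = {lam: l_s + lam * l_c for lam in (0.0, 0.01, 0.1, 1.0)}
print(joint)
```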

 

4. Implementing center loss:

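Section 3 covered the principle: keep the intra-class loss as small as possible while preserving classification. Below we look at how to implement it in the network structure and in code: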

PyTorch users can refer to: https://github.com/jxgu1016/MNIST_center_loss_pytorch

(1) Network structure:

Center loss is applied to the output of the feature layer (the last layer before the classification layer):
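A minimal PyTorch sketch of this structure, assuming a hypothetical toy MLP (the class name Net and all sizes are made up): the model returns both the feature-layer output and the logits, center loss is computed on the features and softmax (cross-entropy) loss on the logits. For brevity the centers here get an ordinary autograd gradient, which is a simplification compared with the paper's Δc_j update used in the code in (3):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Hypothetical toy model: a feature layer followed by a linear classifier."""

    def __init__(self, in_dim=20, feat_dim=2, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, feat_dim))
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        feat = self.features(x)         # deep features x_i: center loss goes here
        logits = self.classifier(feat)  # softmax loss goes on these logits
        return feat, logits


torch.manual_seed(0)
net = Net()
centers = nn.Parameter(torch.randn(10, 2))  # one learnable center per class
x = torch.randn(8, 20)
y = torch.randint(0, 10, (8,))

feat, logits = net(x)
lam = 0.1
loss_s = F.cross_entropy(logits, y)
loss_c = 0.5 * (feat - centers[y]).pow(2).sum(dim=1).mean()  # L_C averaged over the batch
loss = loss_s + lam * loss_c
loss.backward()
```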

(2) If anything is still unclear, work through it together with the algorithm:
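The paper's training algorithm alternates three updates per iteration: backpropagate the joint loss, update W and the network parameters with learning rate μ, and update each center with c_j ← c_j − αΔc_j. A minimal PyTorch sketch of one such iteration, assuming a toy linear layer in place of the CNN; the name train_step and the values of λ, α and μ are hypothetical. Unlike the code in (3), the centers here are a plain tensor updated by hand rather than through a custom autograd Function:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, feat_dim = 3, 2
lam, alpha, mu = 0.1, 0.5, 0.01  # lambda, center learning rate, network learning rate

feature_layer = nn.Linear(5, feat_dim)         # stands in for the conv layers (theta_C)
classifier = nn.Linear(feat_dim, num_classes)  # W and b of the softmax loss layer
centers = torch.zeros(num_classes, feat_dim)   # c_j: plain tensor, updated by hand
opt = torch.optim.SGD(list(feature_layer.parameters()) + list(classifier.parameters()), lr=mu)


def train_step(x, y):
    feat = feature_layer(x)
    logits = classifier(feat)
    # Joint loss L = L_S + lambda * L_C (sums over the mini-batch)
    loss_s = F.cross_entropy(logits, y, reduction="sum")
    loss_c = 0.5 * (feat - centers[y]).pow(2).sum()
    loss = loss_s + lam * loss_c
    # Backprop: dL/dx_i = dL_S/dx_i + lambda * (x_i - c_{y_i}); W only sees dL_S/dW
    opt.zero_grad()
    loss.backward()
    opt.step()  # updates W, b and theta_C with learning rate mu
    # Center update c_j <- c_j - alpha * delta c_j, using this mini-batch's features
    with torch.no_grad():
        diff = centers[y] - feat
        delta = torch.zeros_like(centers).index_add_(0, y, diff)
        counts = torch.ones(num_classes).index_add_(0, y, torch.ones(len(y)))
        centers.sub_(alpha * delta / counts.unsqueeze(1))
    return loss.item()


x = torch.randn(6, 5)
y = torch.tensor([0, 0, 1, 1, 2, 2])
losses = [train_step(x, y) for _ in range(3)]
print(losses)
print(centers)
```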

(3) Code:

import torch
import torch.nn as nn
from torch.autograd import Function


class CenterLoss(nn.Module):
    """Center loss L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2, optionally averaged over the batch."""

    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLoss, self).__init__()
        # One learnable center c_j per class, each of dimension feat_dim
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # Check that the feature dimension matches the center dimension
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim, feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        loss /= (batch_size if self.size_average else 1)
        return loss


class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers):
        ctx.save_for_backward(feature, label, centers)
        # c_{y_i} for every sample in the batch
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        label = label.long()
        centers_batch = centers.index_select(0, label)
        diff = centers_batch - feature  # c_{y_i} - x_i
        # Re-initialized every iteration; counts[j] = 1 + number of class-j samples
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = torch.zeros_like(centers)

        counts.scatter_add_(0, label, ones)
        # Sum (c_j - x_i) over the samples with y_i = j, then divide by (1 + n_j): delta c_j
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()), diff)
        grad_centers = grad_centers / counts.view(-1, 1)
        # Feature gradient: grad_output * (x_i - c_{y_i}). The center "gradient" is the
        # paper's delta c_j and deliberately ignores grad_output; its step size (alpha)
        # is set by the optimizer that updates the centers.
        return -grad_output * diff, None, grad_centers


def main(test_cuda=False):
    print('-' * 80)
    device = torch.device("cuda" if test_cuda else "cpu")
    ct = CenterLoss(10, 2).to(device)
    y = torch.Tensor([0, 0, 2, 1]).to(device)
    feat = torch.zeros(4, 2).to(device).requires_grad_()
    print(list(ct.parameters()))
    print(ct.centers.grad)  # None before the first backward pass
    out = ct(y, feat)
    print(out.item())
    out.backward()
    print(ct.centers.grad)
    print(feat.grad)


if __name__ == "__main__":
    main(test_cuda=torch.cuda.is_available())

 

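5. Extensions:

The paper also discusses how center loss differs from contrastive loss and triplet loss. The clear advantage of center loss over contrastive and triplet loss is that it avoids the complex and somewhat ambiguous process of constructing sample pairs (or triplets). A follow-up post will go over triplet loss.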
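To make the sample-pair point concrete, a rough pure-Python sketch that counts, for one mini-batch, how many candidate terms each loss can draw from: center loss uses one term per sample, contrastive loss chooses among all pairs, and triplet loss among all (anchor, positive, negative) triplets. The function count_terms is hypothetical, and raw counts are only a proxy: practical contrastive/triplet implementations mine a subset of these candidates, and that mining step is the complexity center loss avoids.

```python
from itertools import combinations


def count_terms(labels):
    """Candidate loss terms in one mini-batch: (center, contrastive pairs, triplets)."""
    m = len(labels)
    center_terms = m  # one (x_i, c_{y_i}) term per sample
    pairs = len(list(combinations(range(m), 2)))  # all unordered pairs
    triplets = sum(
        1
        for a in range(m)
        for p in range(m)
        for n in range(m)
        if a != p and labels[a] == labels[p] and labels[n] != labels[a]
    )
    return center_terms, pairs, triplets


print(count_terms([0, 0, 1, 1, 2, 2]))                        # 3 classes x 2 samples
print(count_terms([c for c in range(8) for _ in range(4)]))   # 8 classes x 4 samples
```

Center-loss terms grow linearly with the batch size, while the candidate triplets grow roughly cubically.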

