# 1. Introduction

Center loss was introduced in an ECCV 2016 paper: A Discriminative Feature Learning Approach for Deep Face Recognition.

# 2. Why Use Center Loss

In most of the available CNNs, the softmax loss function is used as the supervision signal to train the deep model. To enhance the discriminative power of the deeply learned features, this paper proposes a new supervision signal, called center loss.

The center loss simultaneously learns a center for the deep features of each class and penalizes the distances between the deep features and their corresponding class centers.

# 3. How Center Loss Makes the Learned Features More Discriminative

Specifically, we learn a center (a vector with the same dimension as a feature) for deep features of each class.

The CNNs are trained under the joint supervision of the softmax loss and the center loss, with a hyperparameter to balance the two supervision signals.
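The joint objective can be sketched as follows; this is a minimal illustration (the 2-D features, 3-class linear head, and λ = 0.5 are arbitrary toy choices, not the paper's setup):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# toy batch: 4 samples with 2-D deep features, 3 classes
feat = torch.randn(4, 2, requires_grad=True)
labels = torch.tensor([0, 0, 2, 1])
logits = nn.Linear(2, 3)(feat)

# softmax (cross-entropy) loss encourages inter-class separation
softmax_loss = nn.CrossEntropyLoss()(logits, labels)

# center loss penalizes each sample's distance to its learnable class center
centers = nn.Parameter(torch.randn(3, 2))
center_loss = (feat - centers[labels]).pow(2).sum() / 2.0 / feat.size(0)

# joint supervision: L = L_S + lambda * L_C
lam = 0.5  # the balancing hyperparameter lambda
loss = softmax_loss + lam * center_loss
loss.backward()
```

Both terms backpropagate into the features, so the network is pushed to separate classes and compact them at the same time.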

The softmax loss (which keeps the features of different classes apart, maximizing inter-class distance) is paired with the center loss (which pulls each feature toward its class center, minimizing intra-class distance).
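As a quick numeric check of the center loss L_C = 1/2 · Σᵢ ‖xᵢ − c_{yᵢ}‖² (toy 2-D values, chosen only for easy arithmetic):

```python
import torch

# two samples of class 0 and one of class 1; centers c_0 = (0,0), c_1 = (1,1)
feat = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 2.0]])
labels = torch.tensor([0, 0, 1])
centers = torch.tensor([[0.0, 0.0], [1.0, 1.0]])

# each sample contributes ||x_i - c_{y_i}||^2: 1 + 1 + 1 = 3, halved -> 1.5
lc = (feat - centers[labels]).pow(2).sum() / 2.0
print(lc.item())  # 1.5
```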

Here m is the mini-batch size and n is the number of classes. The center loss itself is

L_C = 1/2 · Σ_{i=1}^{m} ‖x_i − c_{y_i}‖²

where c_{y_i} ∈ R^d is the center of class y_i (the class of sample i) and d is the feature dimension. Taken literally, this formulation is impractical: each center would have to be recomputed over the entire training set at every iteration. The paper therefore makes two modifications:

1. Update the centers over each mini-batch instead of over the whole training set.

2. To avoid large perturbations caused by mislabeled samples, use a scalar α to control the learning rate of the centers.
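Concretely, the two modifications give the update c_j ← c_j − α·Δc_j, with Δc_j = Σᵢ δ(yᵢ = j)(c_j − xᵢ) / (1 + Σᵢ δ(yᵢ = j)) computed on the mini-batch only. A sketch of that update (random toy values; α = 0.5 is arbitrary):

```python
import torch

torch.manual_seed(0)
num_classes, feat_dim, alpha = 3, 2, 0.5
centers = torch.randn(num_classes, feat_dim)
feat = torch.randn(4, feat_dim)
labels = torch.tensor([0, 0, 2, 1])

# delta_c_j: mean offset (c_j - x_i) over the batch samples of class j
diff = centers[labels] - feat
delta = torch.zeros_like(centers)
delta.scatter_add_(0, labels.unsqueeze(1).expand_as(feat), diff)
counts = torch.bincount(labels, minlength=num_classes).float()
delta = delta / (1.0 + counts).unsqueeze(1)  # the "+1" keeps empty classes safe

# alpha controls how fast the centers move
new_centers = centers - alpha * delta
```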

# 4. Implementing Center Loss

PyTorch users can refer to: https://github.com/jxgu1016/MNIST_center_loss_pytorch

(1) Network structure:

(2) If anything is still unclear, work through it together with the algorithm:

(3) Code:

```python
import torch
import torch.nn as nn
from torch.autograd import Function


class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLoss, self).__init__()
        # one learnable d-dimensional center per class
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # check that the centers' dim matches the features' dim
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(
                self.feat_dim, feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        loss /= (batch_size if self.size_average else 1)
        return loss


class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers):
        ctx.save_for_backward(feature, label, centers)
        centers_batch = centers.index_select(0, label.long())
        # L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2
        return (feature - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # init every iteration: per-class sample counts start at 1 (the paper's 1 + count)
        counts = centers.new(centers.size(0)).fill_(1)
        ones = centers.new(label.size(0)).fill_(1)
        grad_centers = centers.new(centers.size()).fill_(0)
        counts = counts.scatter_add_(0, label.long(), ones)
        # accumulate (c_{y_i} - x_i) per class, then average by class count;
        # this is the paper's center update rule (Eq. 4), not the true gradient
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers / counts.view(-1, 1)
        # gradients w.r.t. feature, label (non-differentiable), centers
        return -grad_output * diff, None, grad_centers


def main(test_cuda=False):
    print('-' * 80)
    device = torch.device("cuda" if test_cuda else "cpu")
    ct = CenterLoss(10, 2).to(device)
    y = torch.Tensor([0, 0, 2, 1]).to(device)
    feat = torch.zeros(4, 2).to(device).requires_grad_()
    print(list(ct.parameters()))
    out = ct(y, feat)
    print(out.item())
    out.backward()
    print(ct.centers.grad)
    print(feat.grad)


if __name__ == '__main__':
    main(test_cuda=False)
    if torch.cuda.is_available():
        main(test_cuda=True)
```
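Plugged into an actual training step, the recipe pairs cross-entropy with the center term. The sketch below assumes a placeholder backbone (not the repo's MNIST net) and uses a second optimizer so the centers get their own learning rate, which plays the role of α:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# placeholder feature extractor and classifier head (illustrative only)
backbone = nn.Sequential(nn.Flatten(), nn.Linear(784, 2))  # 2-D deep features
head = nn.Linear(2, 10)
centers = nn.Parameter(torch.randn(10, 2))

# a separate optimizer lets the centers move at their own rate (the role of alpha)
opt_model = torch.optim.SGD(
    list(backbone.parameters()) + list(head.parameters()), lr=0.01)
opt_center = torch.optim.SGD([centers], lr=0.5)

images = torch.randn(8, 1, 28, 28)  # fake batch standing in for MNIST
labels = torch.randint(0, 10, (8,))

feat = backbone(images)
logits = head(feat)
lam = 1.0  # balancing weight lambda
loss = nn.CrossEntropyLoss()(logits, labels) \
    + lam * (feat - centers[labels]).pow(2).sum() / 2.0 / feat.size(0)

opt_model.zero_grad()
opt_center.zero_grad()
loss.backward()
opt_model.step()
opt_center.step()
```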