center loss框架
从网络的的框架来看,center loss的主要工作是下图中的“Discriminative Features”。
普通的网络框架,在反向传播的过程中,根据类别标签,会将不同的类别划分开。如“Separable Features”所示,一开始两种颜色是混杂的,通过改变网络参数,让不同颜色能被分类器分开,就达到了目的。而这个过程中,只对不同类有要求,同一类没有进行约束。
center loss则是让类内的输出结果更加集中。
为了展示实际的效果,作者在mnist上进行了测试,下图是softmax分类器前面增加的一层的参数,其维度为2,这样就可以进行可视化的显示。
X
是上一层的输出,维度为800(根据论文计算得到),
在没有采用center loss时,不同类别的输出图像是一种花瓣,其特点是同一类的方差较大。可以找到分界线将不同类别区分开,虽然花瓣外尖端与其他类间距很大,花瓣中心的区分很小,很容易造成错误,如橘色区域,红线表示分类线。
如何让同一类颜色更集中呢?文中采用了center loss:
很简单,每个将输出点与这类中心点的距离累加作为损失。
回想方差公式:
是不是很类似?降低center loss其实也可以看作是降低同类的方差。
实现
推荐EncodeTS/TensorFlow_Center_Loss的代码,使用TensorFlow实现,且有详细的中文注释。
center loss流程大致为:
- 初始化权重中心
centers
,形状为[num_classes, len_features],中心值为0 - 在一次iteration中,获取mini-batch中每一个样本对应的中心值,
centers_batch
,形状为[batch_size, feature_length](使用tf.gather
技巧) - 计算loss,特征与中心features - centers_batch的l2范数
- 根据论文公式(3)(4)更新权重中心:
在一个mini-batch中,某一类j 出现了 n 次,分解来看:
- 属于该类的第
i 个样本与中心距离 cj−xi- 同理算出这个类出现的 n 次样本的距离,并汇总求和
- 除以
n+1
- 属于该类的第