算法
Center loss 能够直接对样本特征之间的距离进行约束。Center loss 添加的约束是,特征与同类别的平均特征的距离要足够小,这要求同类特征要接近它们的中心点,公式如下表示:
其中xi 表示第i个样本的提取特征,cyi表示样本i所对应的类别的所有样本特征的平均特征,或者说同类别样本特征的中心点,m表示样本个数。
如何计算cyi是一个难点,通过计算同一类别所有样本的特征,然后求平均值,这种方法是不切实际的,因为我们的训练样本非常庞大。作者另辟蹊径,使用mini-batch中的每个类别的平均特征近似不同类别所有样本的平均特征。这有点像BN中求feature map的均值和方差的思想。在梯度下降的每一次迭代过程中,cj的更新向量是:
其中I()是指示函数,当j 是类别yi时,函数返回1,否则返回0。分母的1是防止mini-batch中没有类别j的样本而导致分母为0。论文中设置了一个cj 的更新速率参数α,控制cj 的更新速度。
训练的总损失函数是:
Center Loss的算法如下: