最近项目中需要 center loss 提升模型的效果,但是 center loss 的实现就有点不确定,看了很多的博客,基本都是臆测,还是看源码来的实在。
下面就大致说下 center loss 的实现:
1、原理:
原理这块大家可以参考别人的博客,或者paper,这里就简单叙述下:让得到全连接层向量距离对应类别中心的距离最小
2、问题
类别中心是动态变化的么?如何进行变化?
(1)是每个epoch结束后使用所有的样本重新聚类计算得到样本中心么?
(2)在每个batch内计算动态变化得到聚类中心
当然是第二种方式,第一种方式太过于直白,最大的问题就是更新的太滞后了,基本上业界没有这样用的。
那么第二种方式该如何实现?每个batch内不一定包含所有的类别图像,维护一个参数矩阵?如何初始化?如何得到类别中心点(聚类还是求均值?)?
3、具体的实现
确实需要一个参数矩阵来维护并更新我们得到的聚类中心,常规能想到的方式就是自定义一个layer,然后再layey种定义参数矩阵等等,最终加入模型进行训练.
还有一种更为简洁的方式就是使用 Embedding 层的方式进行辅助训练,Embedding 层不仅仅可以实现一个维度的映射,而且最重要的是该层里面也有参数,是一个可以被训练的层,因此一切到这里就可以结束了ÿ