本文借鉴网上已有的center loss tensorflow版本代码,记录自己在解读该代码时所遇到的知识点以及疑问。
由于python知识浅薄,很多地方不懂,不要见笑。
def get_center_loss(features,labels,alpha=0.5,num_class=10):
#处理数据集为MNIST,所以num_class为10
#get feature dimension 这是为了初始化centers 使用get_shape()来取得维度 但是不明白为什么后面跟着[1] 待研究
len_features = features.get_shape()[1]
#initailizer class center 用get_variable函数 定义centers
#因为center loss里的center是根据公式计算更新而不是梯度下降更新,所以属性trainable为false
centers = tf.get_variable('centers',[num_class,len_features],dtype=tf.flloat32,initializer = tf.constant_initializer(0),trainable = False)
#为了节省计算量,center loss的中心更新都是在mini batch内进行的,所以需要获得mini batch内的centers
#center的获得是通过tf.gather函数来获取。该函数以labels作为index,从以初始化的centers中抽取出minibatch的centers。(tf.gather(params,indexs),根据