tensorflow实现center loss

实验的数据集是cifar10,在加上centor loss后准确率稍有下降

这里的labels不是one-hot编码的,labels的大小为Batch size * 1,作为索引使用
y_是one-hot编码的标签,通过tf.argmax(y_,1)变换为索引值

Cfg.centorloss_rate为center loss的rate
alpha为中心向量的学习率,设置为0.01时无法收敛,设为0.001时或更小值时可收敛

    labels = tf.argmax(y_ , 1)
    centorloss,centors,centers_update_op = center_loss(y,labels, alpha = 0.0001, num_classes = 10)
    with tf.name_scope('loss_value'):        
        loss=tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=y_, 
                                                            logits=y))+tf.add_n(tf.get_collection('losses')) + Cfg.centorloss_rate * centorloss

# coding: utf-8

import tensorflow as tf

def center_loss(features, labels, alpha, num_classes):
    with tf.variable_scope('center_loss',reuse = tf.AUTO_REUSE):  
        """获取center loss及center的更新op
        
        Arguments:
            features: Tensor,表征样本特征,一般使用某个fc层的输出,shape应该为[batch_size, feature_length].
            labels: Tensor,表征样本label,非one-hot编码,shape应为[batch_size].如cifar10,那么label为batch size大,值为0到10之间的数
            alpha: 0-1之间的数字,控制样本类别中心的学习率,细节参考原文.
            num_classes: 整数,表明总共有多少个类别,网络分类输出有多少个神经元这里就取多少.
        
        Return:
            loss: Tensor,可与softmax loss相加作为总的loss进行优化.
            centers: Tensor,存储样本中心值的Tensor,仅查看样本中心存储的具体数值时有用.
            centers_update_op: op,用于更新样本中心的op,在训练时需要同时运行该op,否则样本中心不会更新
        """
        # 获取特征的维数,例如256维
        len_features = features.get_shape()[1]
        # 建立一个Variable,shape为[num_classes, len_features],用于存储整个网络的样本中心,
        # 设置trainable=False是因为样本中心不是由梯度进行更新的
        centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
            initializer=tf.constant_initializer(0), trainable=False)
        # 将label展开为一维的,输入如果已经是一维的,则该动作其实无必要
    #    labels = tf.reshape(labels, [-1])
        
        # 根据样本label,获取mini-batch中每一个样本对应的中心值
        centers_batch = tf.gather(centers, labels)
#        print('centers_batch:',centers_batch.get_shape().as_list())
        # 计算loss
        loss = tf.nn.l2_loss(features - centers_batch)
        
        # 当前mini-batch的特征值与它们对应的中心值之间的差
        diff = centers_batch - features
        
        # 获取mini-batch中同一类别样本出现的次数,了解原理请参考原文公式(4)
        unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
        appear_times = tf.gather(unique_count, unique_idx)  ###用一个一维的索引数组,将张量中对应索引的向量提取出来
        appear_times = tf.reshape(appear_times, [-1, 1])
        
        diff = diff / tf.cast((1 + appear_times), tf.float32)
        diff = alpha * diff
        
        centers_update_op = tf.scatter_sub(centers, labels, diff)
        '''
        scatter_sub(ref,indices,updates)
        ref:一个可变的Tensor;必须是下列类型之一:float32,float64,int64,int32,uint8,uint16,int16,int8,complex64,complex128,qint8,quint8,qint32,half;应该来自一个Variable节点。
        indices:一个Tensor;必须是以下类型之一:int32,int64;进入ref的第一维度的一个索引的张量。
        updates:一个Tensor。必须与ref具有相同的类型。从ref中减去更新值的张量。
        '''
        ###centers_update_op  更新中心的op,需要run来更新中心
        return loss, centers, centers_update_op
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值