Caffe Source Code Walkthrough (2): center_loss_layer.cpp

  • Definition of the center loss formulas

  • Walkthrough of center_loss_layer.cpp

  • Comparison of center loss and softmax loss on the MNIST dataset

Definition

"Center Loss: simultaneously learning a center for deep features of each class and penalizing the distances between the deep features and their corresponding class centers." Reference paper: A Discriminative Feature Learning Approach for Deep Face Recognition. For an intuitive discussion of center loss, see the Zhihu answer (link).

Formulas

(1) Forward Computation

$$L_c = \frac{1}{2N}\sum_{i=1}^{N}\left\| x_i - c \right\|_2^2 \tag{1}$$

where $c$ is the center of the class that $x_i$ belongs to and $N$ is the mini-batch size.

(2) Backward Computation
$$\frac{\partial L_c}{\partial x_i} = x_i - c \tag{2}$$

(3) Update Equation
$$\Delta c = -\frac{\alpha}{N}\sum_{i=1}^{N}\left( x_i - c \right) \tag{3}$$
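A note that helps when reading the code below: what the implementation actually applies is a per-class average over the mini-batch rather than a global $1/N$, and it folds in the loss weight $\lambda$. With $n_j$ the number of samples of class $j$ in the batch, the combination of the scaling loop in Forward_cpu and Blob::Update() (which computes data = data - diff) amounts to

$$c_j \leftarrow c_j + \frac{\alpha \lambda}{1 + n_j} \sum_{i:\, y_i = j} \left( x_i - c_j \right),$$

where the $1 + n_j$ denominator matches the paper's update rule and comes from initializing center_update_count_ to 1.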

Code

(1) LayerSetUp

namespace caffe{
    template<typename Dtype>
    void CenterLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
        const vector<Blob<Dtype>*>& top){
        LossLayer<Dtype>::LayerSetUp(bottom, top);
        CHECK_EQ(bottom[0]->num(), bottom[1]->num());
        // Two bottoms: feature = bottom[0]->cpu_data(), label = bottom[1]->cpu_data().
        // One top: top[0]->mutable_cpu_data()[0] = loss.
        int channels = bottom[0]->channels();
        int num = bottom[0]->num();
        // Read the layer parameters.
        // loss_weight is the lambda in the paper, balancing the center loss term.
        alpha = this->layer_param_.center_loss_param().alpha();
        lossWeight = this->layer_param_.center_loss_param().loss_weight();
        clusterNum = this->layer_param_.center_loss_param().cluster_num();

        center_info_.Reshape(clusterNum, channels, 1, 1);  // one center per class
        center_loss_.Reshape(num, channels, 1, 1);         // per-sample x_i - c
        center_update_count_.resize(clusterNum);
        // caffe_set zero-initializes the class centers.
        caffe_set(clusterNum * channels, Dtype(0.0), center_info_.mutable_cpu_data());
    }
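For orientation, a network definition might attach the layer as in the prototxt sketch below. This is a hypothetical example: the type string CenterLoss and the center_loss_param fields alpha, loss_weight and cluster_num are inferred from the accessors above, not from upstream Caffe. Note that this implementation reads loss_weight from its own parameter message rather than from the standard LayerParameter loss_weight field.

layer {
  name: "center_loss"
  type: "CenterLoss"          # hypothetical type string for this custom layer
  bottom: "ip2"               # deep features, e.g. a 2-D embedding for MNIST
  bottom: "label"
  top: "center_loss"
  center_loss_param {
    alpha: 0.5                # center update rate
    loss_weight: 0.01         # lambda, weight of the center loss term
    cluster_num: 10           # number of classes (MNIST digits 0-9)
  }
}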

(2) Forward pass

    template<typename Dtype>
    void CenterLossLayer<Dtype>::Forward_cpu(
        const vector<Blob<Dtype>*> &bottom,
        const vector<Blob<Dtype>*> &top){
        // Two bottom inputs: features and labels.
        const Dtype *feature = bottom[0]->cpu_data();
        const Dtype *label = bottom[1]->cpu_data();
        int num = bottom[0]->num();
        int channels = bottom[0]->channels();
        // Reset the loss and the per-class update accumulators.
        Dtype loss = 0;
        caffe_set(clusterNum * channels, Dtype(0.0), center_info_.mutable_cpu_diff());
        for(int i = 0; i < clusterNum; ++i){
            // Counts start at 1, giving the (1 + n_j) denominator of the update rule.
            center_update_count_[i] = 1;
        }
        for(int i = 0; i < num; ++i){
            int targetLabel = label[i];
            // caffe_sub: center_loss_.mutable_cpu_data() = feature - center_info_.cpu_data(),
            // i.e. x_i - c in the formulas.
            caffe_sub(channels, feature + i * channels,
                center_info_.cpu_data() + targetLabel * channels,
                center_loss_.mutable_cpu_data() + i * channels);
            // Accumulate x_i - c into the diff of the matching center, and count the sample.
            caffe_add(channels, center_loss_.cpu_data() + i * channels,
                center_info_.cpu_diff() + targetLabel * channels,
                center_info_.mutable_cpu_diff() + targetLabel * channels);
            center_update_count_[targetLabel]++;
            // Accumulate the center loss per formula (1), scaled by lossWeight,
            // and output it through top[0].
            loss += caffe_cpu_dot(channels, center_loss_.cpu_data() + i * channels,
                center_loss_.cpu_data() + i * channels) * lossWeight / Dtype(2.0) / static_cast<Dtype>(num);
        }
        top[0]->mutable_cpu_data()[0] = loss;
        // Update the class centers c per formula (3). Note the update happens here,
        // in the forward pass; Blob::Update() below applies data -= diff.
        for(int i = 0; i < clusterNum; ++i){
            Dtype scale = -alpha * lossWeight / Dtype(center_update_count_[i]);
            caffe_scal(channels, scale, center_info_.mutable_cpu_diff() + i * channels);
        }
        center_info_.Update();
    }
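To make the arithmetic concrete, here is a small self-contained C++ sketch (hypothetical: plain arrays stand in for Blob and the caffe_* wrappers) that replays the forward pass on a toy batch of three 2-D features, computing the loss of formula (1) and the per-class center update:

#include <cstdio>
#include <vector>

int main() {
    const int num = 3, channels = 2, clusterNum = 2;
    const float alpha = 0.5f, lossWeight = 1.0f;
    // Toy batch: three 2-D features and their labels.
    float feature[num * channels] = {1.f, 0.f,   3.f, 0.f,   0.f, 2.f};
    int   label[num]              = {0, 0, 1};
    std::vector<float> centers(clusterNum * channels, 0.f);    // centers start at zero
    std::vector<float> centerDiff(clusterNum * channels, 0.f);
    std::vector<int>   count(clusterNum, 1);                   // note: starts at 1

    float loss = 0.f;
    for (int i = 0; i < num; ++i) {
        int j = label[i];
        for (int k = 0; k < channels; ++k) {
            float d = feature[i * channels + k] - centers[j * channels + k];  // x_i - c
            centerDiff[j * channels + k] += d;    // accumulate per-class update
            loss += d * d * lossWeight / 2.f / num;
        }
        ++count[j];
    }
    // Apply c += alpha * lossWeight * sum(x_i - c) / (1 + n_j) per class.
    for (int j = 0; j < clusterNum; ++j)
        for (int k = 0; k < channels; ++k)
            centers[j * channels + k] += alpha * lossWeight * centerDiff[j * channels + k] / count[j];

    printf("loss = %f\n", loss);  // (1 + 9 + 4) / (2 * 3) = 2.333...
    printf("c0 = (%f, %f), c1 = (%f, %f)\n",
           centers[0], centers[1], centers[2], centers[3]);
    return 0;
}

For this batch the printed loss is $(1+9+4)/(2 \cdot 3) \approx 2.33$, and each center moves a step toward the mean of its class's features, exactly as in Forward_cpu above.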

(3) Backward pass

    template<typename Dtype>
    void CenterLossLayer<Dtype>::Backward_cpu(
        const vector<Blob<Dtype>*> &top,
        const vector<bool> &propagate_down,
        const vector<Blob<Dtype>*> &bottom){
        int num = bottom[0]->num();
        int channels = bottom[0]->channels();
        // center_loss_.cpu_data() still holds x_i - c from the forward pass;
        // scale it by lossWeight to get the gradient of formula (2).
        caffe_scal(num * channels, lossWeight, center_loss_.mutable_cpu_data());
        Dtype *out = bottom[0]->mutable_cpu_diff();
        // Copy lossWeight * (x_i - c) into bottom[0]'s diff for backpropagation.
        caffe_copy(num * channels, center_loss_.cpu_data(), out);
    }
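A quick sanity check of formula (2) is a central finite difference on a single sample. The standalone sketch below (hypothetical; it does not use Caffe's GradientChecker) compares the analytic gradient $\lambda (x_i - c)$, as produced by Backward_cpu, against $(L(x+\epsilon) - L(x-\epsilon)) / 2\epsilon$ of the per-sample term $\frac{\lambda}{2}\|x_i - c\|_2^2$. Note that, as in the code above, the $1/N$ batch averaging appears only in the reported loss value, not in the gradient:

#include <cstdio>

// Center-loss term for one sample: (lossWeight / 2) * ||x - c||^2.
// As in the layer code, the 1/N batch averaging is left out of the gradient.
static float sampleLoss(const float *x, const float *c, int channels, float lossWeight) {
    float s = 0.f;
    for (int k = 0; k < channels; ++k) {
        float d = x[k] - c[k];
        s += d * d;
    }
    return 0.5f * lossWeight * s;
}

int main() {
    const int channels = 2;
    const float lossWeight = 0.01f, eps = 1e-3f;
    float x[channels] = {1.5f, -0.5f};
    float c[channels] = {1.0f,  0.5f};

    for (int k = 0; k < channels; ++k) {
        // Analytic gradient per formula (2), scaled by lossWeight as in Backward_cpu.
        float analytic = lossWeight * (x[k] - c[k]);
        // Central finite difference on the same coordinate.
        float saved = x[k];
        x[k] = saved + eps; float lp = sampleLoss(x, c, channels, lossWeight);
        x[k] = saved - eps; float lm = sampleLoss(x, c, channels, lossWeight);
        x[k] = saved;
        float numeric = (lp - lm) / (2.f * eps);
        printf("dim %d: analytic %.6f vs numeric %.6f\n", k, analytic, numeric);
    }
    return 0;
}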

Experiments

The full project is open-sourced on GitHub ([link](https://github.com/wangwen39/center-loss)) and is a good starting point for beginners. For feature visualization, see the Caffe example notebook: http://nbviewer.jupyter.org/github/BVLC/caffe/blob/master/examples/siamese/mnist_siamese.ipynb
The MNIST dataset has 10 classes, the handwritten digits 0-9. The comparison below shows that center loss widens the distances between classes while pulling features of the same class closer together, which yields better classification accuracy.
① softmax loss only (feature visualization figure)
② center loss + softmax loss (feature visualization figure)
