Adding a custom ConfusionLayer to Caffe on Win10 to compute and output the confusion matrix

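This article adds a custom layer, ConfusionLayer, modifies solver.cpp, and adds a new Solver parameter so that the confusion matrix is printed during the test phase.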

1. In VS2013, add the header file for the Confusion layer, confusion_layer.hpp, to the libcaffe project

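The header confusion_layer.hpp is adapted from accuracy_layer.hpp. Because the source file includes it as "caffe/layers/confusion_layer.hpp", it belongs in caffe-master\include\caffe\layers. Its content is: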

#ifndef CAFFE_CONFUSION_LAYER_HPP_
#define CAFFE_CONFUSION_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"


namespace caffe {

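/**
 * @brief Computes the elements of the confusion matrix.
 * When using the Confusion layer, set the Solver parameter topname_for_tfpn so
 * that the confusion matrix is printed at test time. The parameter must match
 * the top name of the Confusion layer. If the Confusion layer is used without
 * this parameter, or the parameter does not match the top name, the expected
 * confusion matrix is not printed; each element is shown divided by the number
 * of test iterations instead.
 * For accurate counts, make sure test_iter * test batch_size <= number of test samples.
 */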
template <typename Dtype>
class ConfusionLayer : public Layer<Dtype> {
 public:

  explicit ConfusionLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "Confusion"; }
  virtual inline int ExactNumBottomBlobs() const { return 2; }

  // Exactly one top blob, holding the confusion-matrix elements.
  virtual inline int MinTopBlobs() const { return 1; }
  virtual inline int MaxTopBlobs() const { return 1; }

 protected:
    
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
    for (int i = 0; i < propagate_down.size(); ++i) {
      if (propagate_down[i]) { NOT_IMPLEMENTED; }
    }
  }

  int label_axis_, outer_num_, inner_num_;

  /// Whether to ignore instances with a certain label (kept from AccuracyLayer).
  bool has_ignore_label_;
  /// The label indicating that an instance should be ignored (kept from AccuracyLayer).
  int ignore_label_;
  /// Keeps counts of the number of samples per class (kept from AccuracyLayer; unused here).
  Blob<Dtype> nums_buffer_;
};

}  // namespace caffe

#endif  // CAFFE_CONFUSION_LAYER_HPP_
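The header relies on overriding Layer's virtual methods by their exact names. Caffe's Layer returns -1 (no limit) from MaxTopBlobs by default, so the top-blob limit only takes effect if the derived class spells the name exactly; a misspelling such as MaxTopBlos silently declares a new function instead. The sketch below uses hypothetical stand-in classes (not Caffe's real ones) to show this, and how the C++11 keyword override (supported by VS2013) turns such a typo into a compile error.

```cpp
#include <cassert>

// Hypothetical stand-in for caffe::Layer.
struct BaseLayer {
  virtual ~BaseLayer() {}
  // Like caffe::Layer, -1 means "no upper limit on top blobs".
  virtual int MaxTopBlobs() const { return -1; }
};

struct MisspelledLayer : BaseLayer {
  // Typo: this declares a brand-new function, so BaseLayer::MaxTopBlobs
  // is NOT overridden and the limit is never seen by the framework.
  virtual int MaxTopBlos() const { return 1; }
};

struct CorrectLayer : BaseLayer {
  // 'override' makes the compiler reject this line if the name is wrong.
  int MaxTopBlobs() const override { return 1; }
};

// The framework only ever calls the base-class name.
int MaxTopsSeenByFramework(const BaseLayer& layer) {
  return layer.MaxTopBlobs();
}
```

Calling MaxTopsSeenByFramework on a MisspelledLayer returns -1, while a CorrectLayer returns 1.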


2. In VS2013, add the source file for the Confusion layer, confusion_layer.cpp, to the libcaffe project

The source file confusion_layer.cpp is adapted from accuracy_layer.cpp. Its content is:

#include <algorithm>  // std::max_element
#include <functional>
#include <iterator>   // std::distance
#include <utility>
#include <vector>

#include "caffe/layers/confusion_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

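template <typename Dtype>
void ConfusionLayer<Dtype>::LayerSetUp(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // Initialize the ignore-label settings from accuracy_param, as AccuracyLayer
  // does; otherwise has_ignore_label_ is read uninitialized in Forward_cpu.
  has_ignore_label_ =
    this->layer_param_.accuracy_param().has_ignore_label();
  if (has_ignore_label_) {
    ignore_label_ = this->layer_param_.accuracy_param().ignore_label();
  }
}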

template <typename Dtype>
void ConfusionLayer<Dtype>::Reshape(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  label_axis_ =
      bottom[0]->CanonicalAxisIndex(this->layer_param_.accuracy_param().axis());
  outer_num_ = bottom[0]->count(0, label_axis_);
  inner_num_ = bottom[0]->count(label_axis_ + 1);
  CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count())
      << "Number of labels must match number of predictions; "
      << "e.g., if label axis == 1 and prediction shape is (N, C, H, W), "
      << "label count (number of labels) must be N*H*W, "
      << "with integer values in {0, 1, ..., C-1}.";
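  // For an InnerProduct top blob, this dimension has size num_output,
  // i.e. the number of classes.
  const int num_labels = bottom[0]->shape(label_axis_);
  // The confusion matrix is stored flattened: one axis of num_labels^2 elements.
  vector<int> top_shape(1, num_labels*num_labels);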
  top[0]->Reshape(top_shape);
}

template <typename Dtype>
void ConfusionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();  // predictions (scores)
  const Dtype* bottom_label = bottom[1]->cpu_data();  // ground-truth labels
  const int dim = bottom[0]->count() / outer_num_;
  const int num_labels = bottom[0]->shape(label_axis_);  // number of classes
  Dtype* top_data = top[0]->mutable_cpu_data();

  // Zero the top blob on every forward pass so that counting starts from 0.
  caffe_set(top[0]->count(), Dtype(0), top_data);
  for (int i = 0; i < outer_num_; ++i) {  // loop over all samples in the batch
    for (int j = 0; j < inner_num_; ++j) {
      const int label_value =
          static_cast<int>(bottom_label[i * inner_num_ + j]);
      if (has_ignore_label_ && label_value == ignore_label_) {
        continue;
      }
      if (top.size() > 1) ++nums_buffer_.mutable_cpu_data()[label_value];
      DCHECK_GE(label_value, 0);
      DCHECK_LT(label_value, num_labels);
      // Compare this sample's prediction with its label and increment the
      // corresponding confusion-matrix element.
      std::vector<Dtype> bottom_data_vec;  // scores of all classes for this sample
      for (int k = 0; k < num_labels; ++k) {
        // Equals bottom_data[i * dim + k] when inner_num_ == 1 (InnerProduct output).
        bottom_data_vec.push_back(bottom_data[i * dim + k * inner_num_ + j]);
      }
      // The index of the largest score is the predicted label.
      auto max_var = std::max_element(bottom_data_vec.begin(), bottom_data_vec.end());
      const int predicted_label_value =
          static_cast<int>(std::distance(bottom_data_vec.begin(), max_var));
      // Row = true label, column = predicted label (label -> predicted label).
      top_data[num_labels * label_value + predicted_label_value] += 1;
    }
  }
}

INSTANTIATE_CLASS(ConfusionLayer);
REGISTER_LAYER_CLASS(Confusion);

}  // namespace caffe
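The counting done in Forward_cpu can be tried outside Caffe. The sketch below is a standalone approximation, assuming the common case inner_num_ == 1 (an InnerProduct output of shape N x C), with plain std::vector standing in for blobs; the function name ConfusionCounts is hypothetical. Like the layer, it takes the argmax as the prediction and increments the flattened element C * label + pred.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// scores: N*C row-major class scores; labels: N ground-truth labels.
// Returns a flattened C*C matrix; entry [C*label + pred] counts samples with
// that true label and predicted label.
std::vector<int> ConfusionCounts(const std::vector<float>& scores,
                                 const std::vector<int>& labels,
                                 int num_classes,
                                 bool has_ignore_label = false,
                                 int ignore_label = -1) {
  const int n = static_cast<int>(labels.size());
  assert(static_cast<int>(scores.size()) == n * num_classes);
  // Fresh zeroed matrix per call, like the top blob on each forward pass.
  std::vector<int> matrix(num_classes * num_classes, 0);
  for (int i = 0; i < n; ++i) {
    const int label = labels[i];
    if (has_ignore_label && label == ignore_label) continue;
    assert(label >= 0 && label < num_classes);
    const auto row_begin = scores.begin() + i * num_classes;
    const auto row_end = row_begin + num_classes;
    // Predicted label = index of the largest score (first one on ties).
    const int pred = static_cast<int>(
        std::distance(row_begin, std::max_element(row_begin, row_end)));
    ++matrix[num_classes * label + pred];
  }
  return matrix;
}
```

For scores {0.9, 0.1 | 0.2, 0.8 | 0.7, 0.3 | 0.4, 0.6} and labels {0, 1, 1, 1}, the result is {1, 0, 1, 2}: one 0->0, one 1->0 and two 1->1.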

3. Modify caffe.proto to define the new ConfusionLayer

In the Caffe directory, open caffe-master\src\caffe\proto\caffe.proto and search for "message LayerParameter", where the layer parameters are defined. Look at the comment on the line above it:

// LayerParameter next available layer-specific ID: 151 (last added: anything )

This means the next available ID in LayerParameter is 151, so add the following field inside message LayerParameter:

  // Parameter for the Confusion layer
  optional ConfusionParameter confusion_param = 151;

To make the next addition easier, increment the available ID in that comment to 152:

// LayerParameter next available layer-specific ID: 152 (last added: ConfusionParameter)

Then add a parameter message for the Confusion layer at the end of caffe.proto (it is empty, since the layer has no parameters of its own):

message ConfusionParameter {
}

Because caffe.proto has changed, it must be recompiled; step 6 explains how.

This completes adding the custom layer itself.

4. Modify solver.cpp/solver.hpp to handle the blob output by the Confusion layer specially

First, declare a new member function TestOutPutTFPN in class Solver in solver.hpp. Search for the member function Test and declare TestOutPutTFPN right below it:

void TestOutPutTFPN(const int test_net_id = 0);

Then, starting from the original solver.cpp, add the definition of TestOutPutTFPN and modify TestAll. The modified solver.cpp is:

#include <cmath>  // std::sqrt, used in TestOutPutTFPN
#include <cstdio>

#include <map>
#include <string>
#include <vector>

//#include "caffe/util/bbox_util.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"

namespace caffe {

template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;
}

template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  if (action_request_function_) {
    // If the external request function has been set, call it.
    return action_request_function_();
  }
  return SolverAction::NONE;
}

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  Init(param);
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  SolverParameter param;
  ReadSolverParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
    << std::endl << param.DebugString();
  param_ = param;
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
  CheckSnapshotWritePermissions();
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding code
  InitTrainNet();
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  const string& field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;
  NetParameter net_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  }
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }
  // Set the correct NetState.  We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  // precedence).
  NetState net_state;
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);
  if (Caffe::root_solver()) {
    net_.reset(new Net<Dtype>(net_param));
  } else {
    net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  const int num_test_nets = num_test_net_params + num_test_net_files;
  if (num_generic_nets) {
      CHECK_GE(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  } else {
      CHECK_EQ(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  }
  int test_net_id = 0;
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net_param";
      net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net file: " + param_.test_net(i);
      ReadNetParamsFromTextFileOrDie(param_.test_net(i),
          &net_params[test_net_id]);
  }
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  if (has_net_param) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());
    }
  }
  if (has_net_file) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState.  We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    NetState net_state;
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    }
    net_params[i].mutable_state()->CopyFrom(net_state);
    LOG(INFO)
        << "Creating test net (#" << i << ") specified by " << sources[i];
    if (Caffe::root_solver()) {
      test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    } else {
      test_nets_[i].reset(new Net<Dtype>(net_params[i],
          root_solver_->test_nets_[i].get()));
    }
    test_nets_[i]->set_debug_info(param_.debug_info());
  }
}

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;

  while (iter_ < stop_iter) {  // main optimization loop, runs to the end of the function
    // zero-init the params
    net_->ClearParamDiffs();
    if (param_.test_interval() && iter_ % param_.test_interval() == 0  // run one test pass when the iteration count calls for it
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;  // true on display iterations; never true if display is 0
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();  // forward and backward pass
    }
    loss /= param_.iter_size();
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << ", loss = " << smoothed_loss_;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();  // update the weights

    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;

    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}

// Definition of Solve
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  Step(param_.max_iter() - iter_);  // training
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively).  Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);

    UpdateSmoothedLoss(loss, start_iter, average_loss);

    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}

template <typename Dtype>
void Solver<Dtype>::TestAll() {
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;
       ++test_net_id) {
    // New SolverParameter field topname_for_tfpn (default ""): decides whether
    // the confusion matrix is printed.
    if (param_.topname_for_tfpn() == "") {
      Test(test_net_id);
    } else {
      TestOutPutTFPN(test_net_id);
    }
  }
}

template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {
        if (SolverAction::SNAPSHOT == request) {
          Snapshot();
        } else if (SolverAction::STOP == request) {
          requested_early_exit_ = true;
        }
        request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    // Test builds its log output from the result blobs, so what it can show
    // depends on which output blobs the test net provides.
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    if (i == 0) {  // first iteration: allocate storage
      for (int j = 0; j < result.size(); ++j) {  // for each result blob
        const Dtype* result_vec = result[j]->cpu_data();  // data pointer of the blob
        for (int k = 0; k < result[j]->count(); ++k) {  // for each element of the blob
          test_score.push_back(result_vec[k]);  // store the element in its own slot
          test_score_output_id.push_back(j);  // remember which blob it came from
        }
      }
    } else {
      int idx = 0;  // index into the slots allocated above
      for (int j = 0; j < result.size(); ++j) {  // for each result blob
        const Dtype* result_vec = result[j]->cpu_data();  // data pointer of the blob
        for (int k = 0; k < result[j]->count(); ++k) {  // for each element of the blob
          test_score[idx++] += result_vec[k];  // accumulate over test iterations
        }
      }
    }
  }
  if (requested_early_exit_) {
    LOG(INFO)     << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < test_score.size(); ++i) {  // loop over all result elements
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    // Note: if test_iter * batch_size exceeds the number of test samples, some
    // samples are tested more than once and counted repeatedly, so the
    // confusion matrix becomes inaccurate.
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);  // average over test iterations
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}

template <typename Dtype>
void Solver<Dtype>::TestOutPutTFPN(const int test_net_id){
	CHECK(Caffe::root_solver());
	LOG(INFO) << "Iteration " << iter_
		<< ", Testing net (#" << test_net_id << ")";
	CHECK_NOTNULL(test_nets_[test_net_id].get())->
		ShareTrainedLayersWith(net_.get());
	vector<Dtype> test_score;
	vector<int> test_score_output_id;
	int class_num = 0;  // number of classes
	const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
	Dtype loss = 0;
	// Record the results
	for (int i = 0; i < param_.test_iter(test_net_id); ++i) {  // the actual test iteration loop
		SolverAction::Enum request = GetRequestedAction();
		// Check to see if stoppage of testing/training has been requested.
		while (request != SolverAction::NONE) {
			if (SolverAction::SNAPSHOT == request) {
				Snapshot();
			}
			else if (SolverAction::STOP == request) {
				requested_early_exit_ = true;
			}
			request = GetRequestedAction();
		}
		if (requested_early_exit_) {
			// break out of test loop.
			break;
		}

		Dtype iter_loss;
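		// Test builds its log output from the result blobs, so what it can show
		// depends on which output blobs the test net provides.
		const vector<Blob<Dtype>*>& result =
			test_net->Forward(&iter_loss);  // test-phase forward pass; returns the net's output (result) blobs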
		if (param_.test_compute_loss()) {
			loss += iter_loss;
		}
		if (i == 0) {  // first iteration: allocate storage
			for (int j = 0; j < result.size(); ++j) {
				// Added: determine the number of classes by checking whether the
				// j-th result blob is the top of the Confusion layer.
				const int output_blob_index =
					test_net->output_blob_indices()[j];
				const string& output_name = test_net->blob_names()[output_blob_index];
				if (output_name == param_.topname_for_tfpn()) {  // the proto string field is a std::string
					// The blob holds class_num * class_num elements.
					const double class2 = result[j]->count();
					class_num = static_cast<int>(std::sqrt(class2) + 0.5);
				}

				const Dtype* result_vec = result[j]->cpu_data();  // data pointer of the result blob
				for (int k = 0; k < result[j]->count(); ++k) {  // for each element of the blob
					test_score.push_back(result_vec[k]);  // store the element in its own slot
					test_score_output_id.push_back(j);  // remember which blob it came from
				}
			}
		}
		else {
			int idx = 0;  // index into the slots allocated above
			for (int j = 0; j < result.size(); ++j) {  // for each result blob
				const Dtype* result_vec = result[j]->cpu_data();  // data pointer of the blob
				for (int k = 0; k < result[j]->count(); ++k) {  // for each element of the blob
					test_score[idx++] += result_vec[k];  // accumulate over test iterations
				}
			}
		}
	}
	if (requested_early_exit_) {
		LOG(INFO) << "Test interrupted.";
		return;
	}
	if (param_.test_compute_loss()) {
		loss /= param_.test_iter(test_net_id);
		LOG(INFO) << "Test loss: " << loss;
	}

	int k = 0;  // index of the current element within the confusion matrix
	for (int i = 0; i < test_score.size(); ++i) {  // loop over all result elements
		const int output_blob_index =
			test_net->output_blob_indices()[test_score_output_id[i]];
		const string& output_name = test_net->blob_names()[output_blob_index];
		// Elements of the confusion-matrix blob are printed as raw counts
		// summed over all test iterations (not averaged).
		if (output_name == param_.topname_for_tfpn()) {
			const int num = static_cast<int>(test_score[i]);  // confusion-matrix element
			LOG(INFO) << "    Test net output #" << i << ": "
				<< output_name << ": " << k / class_num << "->" << k % class_num
				<< " = " << num;  // label -> f(x), i.e. true label -> predicted label
			k++;
			continue;
		}

		// Note: if test_iter * batch_size exceeds the number of test samples, some
		// samples are tested more than once and counted repeatedly, so the
		// confusion matrix becomes inaccurate.
		const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
		ostringstream loss_msg_stream;
		const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);  // average over test iterations
		if (loss_weight) {
			loss_msg_stream << " (* " << loss_weight
				<< " = " << loss_weight * mean_score << " loss)";
		}
		LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
			<< mean_score << loss_msg_stream.str();
	}
}

template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
    model_filename = SnapshotToBinaryProto();
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5:
    model_filename = SnapshotToHDF5();
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }

  SnapshotSolverState(model_filename);
}

template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (Caffe::root_solver() && param_.snapshot()) {
    CHECK(param_.has_snapshot_prefix())
        << "In solver params, snapshot is specified but snapshot_prefix is not";
    string probe_filename = SnapshotFilename(".tempfile");
    std::ofstream probe_ofs(probe_filename.c_str());
    if (probe_ofs.good()) {
      probe_ofs.close();
      std::remove(probe_filename.c_str());
    } else {
      LOG(FATAL) << "Cannot write to snapshot prefix '"
          << param_.snapshot_prefix() << "'.  Make sure "
          << "that the directory exists and is writeable.";
    }
  }
}

template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
    + extension;
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  string model_filename = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
  NetParameter net_param;
  net_->ToProto(&net_param, param_.snapshot_diff());
  WriteProtoToBinaryFile(net_param, model_filename);
  return model_filename;
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  string model_filename = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
  net_->ToHDF5(model_filename, param_.snapshot_diff());
  return model_filename;
}

template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  CHECK(Caffe::root_solver());
  string state_filename(state_file);
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
    RestoreSolverStateFromHDF5(state_filename);
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);
  }
}

template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
    int average_loss) {
  if (losses_.size() < average_loss) {
    losses_.push_back(loss);
    int size = losses_.size();
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
  } else {
    int idx = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[idx]) / average_loss;
    losses_[idx] = loss;
  }
}

INSTANTIATE_CLASS(Solver);

}  // namespace caffe
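TestOutPutTFPN differs from Test only in how it treats the blob named by topname_for_tfpn. It sums that blob over the test_iter forward passes and prints the raw sums as "label->pred = count" instead of dividing by test_iter, and it recovers the class count as the square root of the blob size. The sketch below reproduces that bookkeeping with plain vectors standing in for Caffe blobs; the function names are hypothetical and the output format is copied from the LOG statement above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Recover the class count from a flattened C*C confusion blob.
int ClassNumFromCount(int count) {
  const int c = static_cast<int>(std::sqrt(static_cast<double>(count)) + 0.5);
  assert(c * c == count);  // the blob must really be square
  return c;
}

// Sum the per-iteration confusion blobs (what test_score does across the
// test_iter forward passes) and format them as "top: label->pred = count".
std::vector<std::string> FormatConfusion(
    const std::vector<std::vector<float> >& per_iter_blobs,
    const std::string& top_name) {
  assert(!per_iter_blobs.empty());
  std::vector<float> sums(per_iter_blobs[0].size(), 0.f);
  for (const auto& blob : per_iter_blobs)
    for (std::size_t k = 0; k < blob.size(); ++k) sums[k] += blob[k];
  const int c = ClassNumFromCount(static_cast<int>(sums.size()));
  std::vector<std::string> lines;
  for (std::size_t k = 0; k < sums.size(); ++k) {
    std::ostringstream os;
    // Row index k / c is the true label, column k % c the predicted label.
    os << top_name << ": " << k / c << "->" << k % c << " = "
       << static_cast<int>(sums[k]);
    lines.push_back(os.str());
  }
  return lines;
}
```

With two test iterations producing {3, 1, 0, 2} and {2, 0, 1, 4}, the formatted lines are "TFPN: 0->0 = 5", "TFPN: 0->1 = 1", "TFPN: 1->0 = 1" and "TFPN: 1->1 = 6". Every other output (accuracy, loss) still goes through the mean_score path and is divided by test_iter.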

5. Modify caffe.proto to add the Solver parameter topname_for_tfpn

In the Caffe directory, open caffe-master\src\caffe\proto\caffe.proto and search for "message SolverParameter", where the Solver parameters are defined. Look at the comment on the line above it:

// SolverParameter next available ID: 41 (last added: anything)

This means the next available ID in SolverParameter is 41, so add the field topname_for_tfpn inside message SolverParameter:

  // Selects the Test variant in solver.cpp that also prints the confusion matrix
  optional string topname_for_tfpn = 41 [default = ""];

Then, to make the next addition easier, increment the available ID in that comment to 42:

// SolverParameter next available ID: 42 (last added: topname_for_tfpn)

6. Recompile caffe.proto

Because caffe.proto has been modified, caffe.pb.cc and caffe.pb.h must be regenerated.

First download Protocol Buffers v2.6.1 and put the extracted protoc.exe in the same folder as caffe.proto. Then create a .bat file in that folder with the following content:

protoc.exe caffe.proto --cpp_out=.\

pause

Save it and double-click the .bat file to compile the proto and generate the new caffe.pb.cc and caffe.pb.h.

Then copy the new caffe.pb.h into caffe-master\include\caffe\proto and the new caffe.pb.cc into caffe-master\src\caffe\proto, replacing the original files.

7. Usage

After completing the six steps above, rebuild the Caffe project in VS2013. You can then configure a Confusion layer for your own network in two steps:

  1. Add the following layer at the end of your train_test.prototxt:

    layer {
      name: "confusion"
      type: "Confusion"
      bottom: "ip"	# change to the top of your network's last layer
      bottom: "label"
      top: "TFPN"	# must match the parameter topname_for_tfpn
      include {
        phase: TEST
      }
    }
    
  2. Set the parameter topname_for_tfpn in solver.prototxt:

    topname_for_tfpn: "TFPN"	# must match the top name of the Confusion layer
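The header comment asks for test_iter * test batch_size <= number of test samples. Caffe's data layers generally wrap around to the first sample when they reach the end of the data, so a larger test_iter counts some samples twice. Choosing test_iter is integer arithmetic, as the small sketch below shows. The helper names are hypothetical, and the sample counts in the usage note are illustrative, not taken from the article.

```cpp
#include <cassert>

// Largest test_iter with test_iter * batch_size <= num_samples,
// i.e. no test sample is fed twice within one test phase.
int MaxTestIterWithoutRepeats(int num_samples, int batch_size) {
  assert(num_samples > 0 && batch_size > 0);
  return num_samples / batch_size;  // integer division rounds down
}

// Number of samples never evaluated when using that test_iter.
int SamplesLeftOut(int num_samples, int batch_size) {
  return num_samples -
         batch_size * MaxTestIterWithoutRepeats(num_samples, batch_size);
}
```

For example, with a hypothetical 1000 test images and batch_size 64, test_iter should be at most 15, which leaves 40 images unevaluated. A batch size that divides the sample count evenly (e.g. 64 for 4096 images) covers every sample exactly once.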
    

The confusion matrix printed during the test phase then appears in the training log. For a two-class example:

I0112 16:33:26.690425 13844 solver.cpp:421] Iteration 1000, Testing net (#0)
I0112 16:33:28.491104 13844 solver.cpp:505]     Test net output #0: TFPN: 0->0 = 2892
I0112 16:33:28.491104 13844 solver.cpp:505]     Test net output #1: TFPN: 0->1 = 150
I0112 16:33:28.491104 13844 solver.cpp:505]     Test net output #2: TFPN: 1->0 = 325
I0112 16:33:28.491104 13844 solver.cpp:505]     Test net output #3: TFPN: 1->1 = 729
I0112 16:33:28.491104 13844 solver.cpp:519]     Test net output #4: accuracy = 0.884033
I0112 16:33:28.492105 13844 solver.cpp:519]     Test net output #5: loss = 0.321114 (* 1 = 0.321114 loss)
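The logged counts can be cross-checked and turned into further metrics. The four TFPN entries sum to 2892 + 150 + 325 + 729 = 4096 samples, and (2892 + 729) / 4096 ≈ 0.884033, which agrees with the accuracy output. The sketch below uses hypothetical helper names and the same layout as ConfusionLayer (row = true label, column = predicted label). It computes accuracy and per-class precision and recall from the flattened matrix.

```cpp
#include <cassert>
#include <vector>

// m is a flattened c*c confusion matrix: m[c*label + pred].
double Accuracy(const std::vector<double>& m, int c) {
  double correct = 0, total = 0;
  for (int i = 0; i < c; ++i)
    for (int j = 0; j < c; ++j) {
      total += m[c * i + j];
      if (i == j) correct += m[c * i + j];  // diagonal = correct predictions
    }
  return total > 0 ? correct / total : 0.0;
}

// Precision of class k: diagonal entry / column sum (all samples predicted as k).
double Precision(const std::vector<double>& m, int c, int k) {
  double col = 0;
  for (int i = 0; i < c; ++i) col += m[c * i + k];
  return col > 0 ? m[c * k + k] / col : 0.0;
}

// Recall of class k: diagonal entry / row sum (all samples whose label is k).
double Recall(const std::vector<double>& m, int c, int k) {
  double row = 0;
  for (int j = 0; j < c; ++j) row += m[c * k + j];
  return row > 0 ? m[c * k + k] / row : 0.0;
}
```

For the log above, class 1 has precision 729 / 879 ≈ 0.829 and recall 729 / 1054 ≈ 0.692.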