如何使用Caffe模型和权值进行开发

可能看到网上有很多关于适用caffe模型和权值实现各种应用的程序,但是怎么实现的呢。下面以一个程序一步步讲讲:
先上程序:

Detector::Detector(const string& model_file,
                   const string& weights_file) {
   google::InitGoogleLogging("ssd");   
   google::SetCommandLineOption("GLOG_minloglevel", "2"); 

  #ifdef CPU_ONLY
    Caffe::set_mode(Caffe::CPU);
  #else
    Caffe::set_mode(Caffe::GPU);
  #endif

  net_.reset(new Net<float>(model_file, TEST));//1.(1)
  net_->CopyTrainedLayersFrom(weights_file);//1.(1)

  CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
  CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";

  Blob<float>* input_layer = net_->input_blobs()[0];//1.(2)
  num_channels_ = input_layer->channels();//1.(2)
  CHECK(num_channels_ == 3 || num_channels_ == 1)
    << "Input layer should have 1 or 3 channels.";
  input_geometry_ = cv::Size(input_layer->width(), input_layer->height());//1.(2)
  google::ShutdownGoogleLogging();
}
std::vector<vector<float> > Detector::Detect(const cv::Mat& img) {
  Blob<float>* input_layer = net_->input_blobs()[0];//2.(1)
  input_layer->Reshape(1, num_channels_,
                       input_geometry_.height, input_geometry_.width);//2.(1)
  /* Forward dimension change to all layers. */
  net_->Reshape();//2.(2)
  std::vector<cv::Mat> input_channels;
  WrapInputLayer(&input_channels);//2.(3)
  Preprocess(img, &input_channels);//2.(4)
  net_->Forward();//2.(5)
  /* Copy the output layer to a std::vector */
  Blob<float>* result_blob = net_->output_blobs()[0];//2.(6)
  const float* result = result_blob->cpu_data();
  const int num_det = result_blob->height();
  vector<vector<float> > detections;
  for (int k = 0; k < num_det; ++k) {
    if (result[0] == -1) {
      // Skip invalid detection.
      result += 7;
      continue;
    }
    vector<float> detection(result, result + 7);
    detections.push_back(detection);
    result += 7;
  }
  return detections;
}

void Detector::WrapInputLayer(std::vector<cv::Mat>* input_channels) {
  Blob<float>* input_layer = net_->input_blobs()[0];

  int width = input_layer->width();
  int height = input_layer->height();
  float* input_data = input_layer->mutable_cpu_data();
  for (int i = 0; i < input_layer->channels(); ++i) {
    cv::Mat channel(height, width, CV_32FC1, input_data);
    input_channels->push_back(channel);
    input_data += width * height;
  }
}

void Detector::Preprocess(const cv::Mat& img,
                            std::vector<cv::Mat>* input_channels) {
  /* Convert the input image to the input image format of the network. */
  cv::Mat sample;
  if (img.channels() == 3 && num_channels_ == 1)
    cv::cvtColor(img, sample, cv::COLOR_BGR2GRAY);
  else if (img.channels() == 4 && num_channels_ == 1)
    cv::cvtColor(img, sample, cv::COLOR_BGRA2GRAY);
  else if (img.channels() == 4 && num_channels_ == 3)
    cv::cvtColor(img, sample, cv::COLOR_BGRA2BGR);
  else if (img.channels() == 1 && num_channels_ == 3)
    cv::cvtColor(img, sample, cv::COLOR_GRAY2BGR);
  else
    sample = img;

  cv::Mat sample_resized;
  if (sample.size() != input_geometry_)
    cv::resize(sample, sample_resized, input_geometry_);
  else
    sample_resized = sample;

  cv::Mat sample_float;
  if (num_channels_ == 3)
    sample_resized.convertTo(sample_float, CV_32FC3, 0.0078125,-127.5*0.0078125);
  else
    sample_resized.convertTo(sample_float, CV_32FC1, 0.0078125,-127.5*0.0078125);

  cv::split(sample_float, *input_channels);
  CHECK(reinterpret_cast<float*>(input_channels->at(0).data)
        == net_->input_blobs()[0]->cpu_data())
    << "Input channels are not wrapping the input layer of the network.";
}

一般的caffe的应用程序都会有这么几个函数,按照自己设计的网络的不同可能再需要增加一些其他的函数。
1、先看第一函数Detector:
(1)这里定义为了其构造函数,这里的构造函数主要实现对于训练好的caffe的模型和权值的加载任务。
(2)创建input_layer,获取输入到网络的图片的大小。
2、总的执行函数是Detect:
(1)创建输入输入对象Blob,初始化输入层。
(2)初始化整个网络。
(3)创建输入层的各个像素通道。
(4)填充各个像素通道。
(5)网络前向传播
(6)获取输出结果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值