caffe专题五——回归上

转自:https://blog.csdn.net/qq295456059/article/details/53142574

最近项目需要用到caffe来做关键点的回归,即通过caffe来训练一个网络,输出的结果不是简单地类别,而是一些坐标(浮点数)。

下面的这篇博文对caffe做回归有一个比较好的介绍:

http://www.cnblogs.com/frombeijingwithlove/p/5314042.html

这篇博文使用的是HDF5+python的方式。而我采用的是直接修改caffe的.cpp文件,并重新编译的方式,两种方式各有利弊,我个人认为理解并修改源码对进一步理解caffe很有帮助。当然配置了faster-rcnn或者SSD之后也可以做回归。

caffe本来就“擅长”于做分类任务,所以要拿caffe来做回归任务,就需要对caffe的源码做一些修改。修改的地方主要是下面两大部分:ps~这里可以借鉴作者的思路!

1、 制作lmdb文件相关的代码(即修改convert_imageset.cpp文件):image to Datum

2、 读取lmdb文件相关代码(即修改data_layer.cpp文件):Datum to Blob

根据这两大部分,我将博文分为上下两篇,本文为上篇,关于如何制作用于回归的lmdb文件。

首先,看一看用于分类的txt文件



后面带有多个归一化的坐标(上面的是我随便举的例子,没有实际的意义),实际应用中它们可能代表着某一个BoundingBox(边框回归)的坐标,或者是脸部一些关键点的坐标(上篇文章有介绍)。

下面我将一一列出需要修改代码的地方,带有//###标记的就是我修改的地方

首先是对tools/convert_imageset.cpp进行修改,复制tools/convert_imageset.cpp,并重新命名,这里姑且命名为convert_imageset_regression.cpp,依然放在tools文件夹下面。

[cpp]  view plain  copy
  1. // This program converts a set of images to a lmdb/leveldb by storing them  
  2. // as Datum proto buffers.  
  3. // Usage:  
  4. //   convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME  
  5. //  
  6. // where ROOTFOLDER is the root folder that holds all the images, and LISTFILE  
  7. // should be a list of files as well as their labels, in the format as  
  8. //   subfolder1/file1.JPEG 7  
  9. //   ....  
  10.   
  11. #include <algorithm>  
  12. #include <fstream>  // NOLINT(readability/streams)  
  13. #include <string>  
  14. #include <utility>  
  15. #include <vector>  
  16.   
  17. #include "boost/scoped_ptr.hpp"  
  18. #include "gflags/gflags.h"  
  19. #include "glog/logging.h"  
  20.   
  21. #include "caffe/proto/caffe.pb.h"  
  22. #include "caffe/util/db.hpp"  
  23. #include "caffe/util/format.hpp"  
  24. #include "caffe/util/io.hpp"  
  25. #include "caffe/util/rng.hpp"  
  26.   
  27. #include <boost/tokenizer.hpp> //### To use tokenizer  
  28. #include <iostream> //###  
  29.   
  30. using namespace caffe;  // NOLINT(build/namespaces)  
  31. using std::pair;  
  32. using boost::scoped_ptr;  
  33.   
  34. using namespace std;  //###  
  35.   
  36. DEFINE_bool(gray, false,  
  37.     "When this option is on, treat images as grayscale ones");  
  38. DEFINE_bool(shuffle, false,  
  39.     "Randomly shuffle the order of images and their labels");  
  40. DEFINE_string(backend, "lmdb",  
  41.         "The backend {lmdb, leveldb} for storing the result");  
  42. DEFINE_int32(resize_width, 0, "Width images are resized to");  
  43. DEFINE_int32(resize_height, 0, "Height images are resized to");  
  44. DEFINE_bool(check_size, false,  
  45.     "When this option is on, check that all the datum have the same size");  
  46. DEFINE_bool(encoded, false,  
  47.     "When this option is on, the encoded image will be save in datum");  
  48. DEFINE_string(encode_type, "",  
  49.     "Optional: What type should we encode the image as ('png','jpg',...).");  
  50.   
  51. int main(int argc, char** argv) {  
  52. #ifdef USE_OPENCV  
  53.   ::google::InitGoogleLogging(argv[0]);  
  54.   // Print output to stderr (while still logging)  
  55.   FLAGS_alsologtostderr = 1;  
  56.   
  57. #ifndef GFLAGS_GFLAGS_H_  
  58.   namespace gflags = google;  
  59. #endif  
  60.   
  61.   gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"  
  62.         "format used as input for Caffe.\n"  
  63.         "Usage:\n"  
  64.         "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"  
  65.         "The ImageNet dataset for the training demo is at\n"  
  66.         "    http://www.image-net.org/download-images\n");  
  67.   gflags::ParseCommandLineFlags(&argc, &argv, true);  
  68.   
  69.   if (argc < 4) {  
  70.     gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");  
  71.     return 1;  
  72.   }  
  73.   
  74.   const bool is_color = !FLAGS_gray;  
  75.   const bool check_size = FLAGS_check_size;  
  76.   const bool encoded = FLAGS_encoded;  
  77.   const string encode_type = FLAGS_encode_type;  
  78.     
  79.   std::ifstream infile(argv[2]);  
  80.   //std::vector<std::pair<std::string, int> > lines;  //###  
  81.   std::vector<std::pair<std::string, std::vector<float> > > lines;  
  82.   std::string line;  
  83.   //size_t pos;  
  84.   //int label;  //###  
  85.   std::vector<float> labels;  
  86.   
  87.   while (std::getline(infile, line)) {  
  88.     // pos = line.find_last_of(' ');  
  89.     // label = atoi(line.substr(pos + 1).c_str());  
  90.     // lines.push_back(std::make_pair(line.substr(0, pos), label));  
  91.     //###  
  92.     std::vector<std::string> tokens;  
  93.     boost::char_separator<char> sep(" ");  
  94.     boost::tokenizer<boost::char_separator<char> > tok(line, sep);  
  95.     tokens.clear();  
  96.     std::copy(tok.begin(), tok.end(), std::back_inserter(tokens));    
  97.   
  98.     for (int i = 1; i < tokens.size(); ++i)  
  99.     {  
  100.       labels.push_back(atof(tokens.at(i).c_str()));  
  101.     }  
  102.       
  103.     lines.push_back(std::make_pair(tokens.at(0), labels));  
  104.     //###To clear the vector labels  
  105.     labels.clear();  
  106.   }  
  107.   if (FLAGS_shuffle) {  
  108.     // randomly shuffle data  
  109.     LOG(INFO) << "Shuffling data";  
  110.     shuffle(lines.begin(), lines.end());  
  111.   }  
  112.   LOG(INFO) << "A total of " << lines.size() << " images.";  
  113.   
  114.   if (encode_type.size() && !encoded)  
  115.     LOG(INFO) << "encode_type specified, assuming encoded=true.";  
  116.   
  117.   int resize_height = std::max<int>(0, FLAGS_resize_height);  
  118.   int resize_width = std::max<int>(0, FLAGS_resize_width);  
  119.   
  120.   // Create new DB  
  121.   scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));  
  122.   db->Open(argv[3], db::NEW);  
  123.   scoped_ptr<db::Transaction> txn(db->NewTransaction());  
  124.   
  125.   // Storing to db  
  126.   std::string root_folder(argv[1]);  
  127.   Datum datum;  
  128.   int count = 0;  
  129.   int data_size = 0;  
  130.   bool data_size_initialized = false;  
  131.   
  132.   for (int line_id = 0; line_id < lines.size(); ++line_id) {  
  133.     bool status;  
  134.     std::string enc = encode_type;  
  135.     if (encoded && !enc.size()) {  
  136.       // Guess the encoding type from the file name  
  137.       string fn = lines[line_id].first;  
  138.       size_t p = fn.rfind('.');  
  139.       if ( p == fn.npos )  
  140.         LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";  
  141.       enc = fn.substr(p);  
  142.       std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);  
  143.     }  
  144.     status = ReadImageToDatum(root_folder + lines[line_id].first,   //###  
  145.         lines[line_id].second, resize_height, resize_width, is_color,  
  146.         enc, &datum);  
  147.     if (status == falsecontinue;  
  148.     if (check_size) {  
  149.       if (!data_size_initialized) {  
  150.         data_size = datum.channels() * datum.height() * datum.width();  
  151.         data_size_initialized = true;  
  152.       } else {  
  153.         const std::string& data = datum.data();  
  154.         CHECK_EQ(data.size(), data_size) << "Incorrect data field size "  
  155.             << data.size();  
  156.       }  
  157.     }  
  158.     // sequential  
  159.     string key_str = caffe::format_int(line_id, 8) + "_" + lines[line_id].first;  
  160.   
  161.     // Put in db  
  162.     string out;  
  163.     CHECK(datum.SerializeToString(&out));  
  164.     txn->Put(key_str, out);  
  165.   
  166.     if (++count % 1000 == 0) {  
  167.       // Commit db  
  168.       txn->Commit();  
  169.       txn.reset(db->NewTransaction());  
  170.       LOG(INFO) << "Processed " << count << " files.";  
  171.     }  
  172.   }  
  173.   // write the last batch  
  174.   if (count % 1000 != 0) {  
  175.     txn->Commit();  
  176.     LOG(INFO) << "Processed " << count << " files.";  
  177.   }  
  178. #else  
  179.   LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV.";  
  180. #endif  // USE_OPENCV  
  181.   return 0;  
  182. }  

上面的代码主要有两处进行了修改:一处是读取txt文件部分, 第二处是ReadImageToDatum函数。

首先,原来的label是一个int类型的变量,现在的label是多个float类型的变量,所以就有了下面的修改:


[cpp]  view plain  copy
  1. //std::vector<std::pair<std::string, int> > lines;  //###  
  2. std::vector<std::pair<std::string, std::vector<float> > > lines;  
  3. std::string line;  
  4. //size_t pos;  
  5. //int label;  //###  
  6. std::vector<float> labels;  

用float类型的vector来存放label,然后在读取txt文件的while循环中修改读取label部分的代码。

第一处修改完成之后,接下来需要对ReadImageToDatum函数进行修改,这个函数的作用是将图片的信息写入到Datum中,对Datum,Blob还不太了解的朋友可以参考下面这篇博文:http://www.cnblogs.com/yymn/articles/4479216.html,这里先暂时将Datum理解为一个存放图片信息(包括像素值和label)的数据结构,用于将图片写入到lmdb文件。

ReadImageToDatum函数在io.hpp中声明,我是使用sublime text3打开(open folder)caffe文件夹,直接选中ReadImageToDatum右键就可以“Goto Definition”。

在io.hpp文件中,原来的ReadImageToDatum函数是像下面这样声明的:


[cpp]  view plain  copy
  1. bool ReadImageToDatum(const string& filename, const int label,  
  2.     const int height, const int width, const bool is_color,  
  3.     const std::string & encoding, Datum* datum);  

我们可以不改动原来的函数声明(因为C++支持函数重载,这里指参数有所不同),而在它的下面接上:

[cpp]  view plain  copy
  1. bool ReadImageToDatum(const string& filename, const vector<float> labels,  
  2.     const int height, const int width, const bool is_color,  
  3.     const std::string & encoding, Datum* datum);  

容易注意到,我们参原来的参数

[cpp]  view plain  copy
  1. const int label  
修改成:

[cpp]  view plain  copy
  1. const vector<float> labels  
接着,我们需要在io.cpp函数中实现我们增加的重载函数:


[cpp]  view plain  copy
  1. bool ReadImageToDatum(const string& filename, const vector<float> labels,  
  2.     const int height, const int width, const bool is_color,  
  3.     const std::string & encoding, Datum* datum) {  
  4.   cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);  
  5.   if (cv_img.data) {  
  6.     // if (encoding.size()) {  
  7.     //   if ( (cv_img.channels() == 3) == is_color && !height && !width &&  
  8.     //       matchExt(filename, encoding) )  
  9.     //     return ReadFileToDatum(filename, label, datum);  
  10.     //   std::vector<uchar> buf;  
  11.     //   cv::imencode("."+encoding, cv_img, buf);  
  12.     //   datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),  
  13.     //                   buf.size()));  
  14.     //   datum->set_label(label);  
  15.     //   datum->set_encoded(true);  
  16.     //   return true;  
  17.     // }  
  18.                       
  19.     CVMatToDatum(cv_img, datum);  
  20.     //datum->set_label(label);  
  21.   
  22.     //###  
  23.     for (int i = 0; i < labels.size(); ++i)  
  24.     {  
  25.       datum->add_float_data(labels.at(i));  
  26.     }  
  27.   
  28.     return true;  
  29.   } else {  
  30.     return false;  
  31.   }  
  32. }  

在原来的ReadImageToDatum定义下面加上新的定义,(BTW:encoding部分对我暂时没有什么用,所以暂时注释掉)。这里使用:

[cpp]  view plain  copy
  1. datum->add_float_data(labels.at(i));  
将label写入到Datum中。

好了!经过上面的步骤,回到caffe目录下,重新make编译一下,就会在build/tools/文件夹下面生成一个convert_imageset_regression.bin可执行文件了。

再接下来制作lmdb的方法就跟分类任务一样了,需要制作我们的train.txt以及test.txt,以及将我们用于train和test的图片放到相应的文件夹下面,然后调用convert_imageset_regression.bin来制作lmdb即可,经过上面的代码修改,convert_imageset_regression.bin已经“懂得”如何将后面带有多个浮点类型的数字的txt转换成lmdb文件啦!


这里,可能有的朋友还会有一点疑问,

[cpp]  view plain  copy
  1. datum->add_float_data(labels.at(i));  
这个函数是怎么来的,第一次用的时候怎么会知道有这个函数?

这就得来看看caffe.proto文件了,里面关于Datum的代码如下:


[plain]  view plain  copy
  1. message Datum {  
  2.   optional int32 channels = 1;  
  3.   optional int32 height = 2;  
  4.   optional int32 width = 3;  
  5.   // the actual image data, in bytes  
  6.   optional bytes data = 4;  
  7.   optional int32 label = 5;  
  8.   // Optionally, the datum could also hold float data.  
  9.   repeated float float_data = 6;  
  10.   // If true data contains an encoded image that need to be decoded  
  11.   optional bool encoded = 7 [default = false];  
  12. }  

.proto文件是Google开发的一种协议接口,根据这个,可以自动生成caffe.pb.h和caffe.pb.cc文件。

其中,

[plain]  view plain  copy
  1. optional int32 label = 5;  
就是用于分类的。

而,

[plain]  view plain  copy
  1. repeated float float_data = 6;  

就是我们用来做回归的。

在caffe.pb.h文件中可以找到关于这部分自动生成的代码:


在上篇中,我们已经实现了lmdb的制作,实际上就是将训练和测试的图片的信息存放在Datum中,然后再序列化到lmdb文件中。

上篇完成了数据的准备工作,而要跑通整个实验,还需要在data_layer.cpp中做一些相应的修改。

data_layer.cpp中的函数实现了从lmdb中读取图片信息,先是反序列化成Datum,然后再放进Blob中。仔细想一下可以知道,因为原先caffe的data_layer.cpp的实现是针对分类的情况,所以读取label部分的代码并不适用于回归的情况。

所以本篇介绍data_layer.cpp需要修改的代码,以及训练的时候需要注意的一些细节。

下面是我修改后的data_layer.cpp文件,主要修改了两处地方:一是DataLayerSetup函数,二是load_batch函数。同上篇一样,有//###标记的就是我修改的地方


[cpp]  view plain  copy
  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3. #endif  // USE_OPENCV  
  4. #include <stdint.h>  
  5.   
  6. #include <vector>  
  7.   
  8. #include "caffe/data_transformer.hpp"  
  9. #include "caffe/layers/data_layer.hpp"  
  10. #include "caffe/util/benchmark.hpp"  
  11.   
  12. namespace caffe {  
  13.   
  14. template <typename Dtype>  
  15. DataLayer<Dtype>::DataLayer(const LayerParameter& param)  
  16.   : BasePrefetchingDataLayer<Dtype>(param),  
  17.     reader_(param) {  
  18. }  
  19.   
  20. template <typename Dtype>  
  21. DataLayer<Dtype>::~DataLayer() {  
  22.   this->StopInternalThread();  
  23. }  
  24.   
  25. template <typename Dtype>  
  26. void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  27.       const vector<Blob<Dtype>*>& top) {  
  28.   const int batch_size = this->layer_param_.data_param().batch_size();  
  29.   // Read a data point, and use it to initialize the top blob.  
  30.   Datum& datum = *(reader_.full().peek());  
  31.   
  32.   // Use data_transformer to infer the expected blob shape from datum.  
  33.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  34.   this->transformed_data_.Reshape(top_shape);  
  35.   // Reshape top[0] and prefetch_data according to the batch_size.  
  36.   top_shape[0] = batch_size;  
  37.   top[0]->Reshape(top_shape);  
  38.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  39.     this->prefetch_[i].data_.Reshape(top_shape);  
  40.   }  
  41.   LOG(INFO) << "output data size: " << top[0]->num() << ","  
  42.       << top[0]->channels() << "," << top[0]->height() << ","  
  43.       << top[0]->width();  
  44.   
  45.   // label  
  46.   //###  
  47.   // if (this->output_labels_) {  
  48.   //   vector<int> label_shape(1, batch_size);  
  49.   //   top[1]->Reshape(label_shape);  
  50.   //   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  51.   //     this->prefetch_[i].label_.Reshape(label_shape);  
  52.   //   }  
  53.   // }  
  54.   
  55.   //###  
  56.   int labelNum = 4;  
  57.   if (this->output_labels_) {  
  58.   
  59.     vector<int> label_shape;  
  60.     label_shape.push_back(batch_size);  
  61.     label_shape.push_back(labelNum);  
  62.     label_shape.push_back(1);  
  63.     label_shape.push_back(1);  
  64.     top[1]->Reshape(label_shape);  
  65.     for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  66.       this->prefetch_[i].label_.Reshape(label_shape);  
  67.     }  
  68.   }  
  69. }  
  70.   
  71. // This function is called on prefetch thread  
  72. template<typename Dtype>  
  73. void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  74.   CPUTimer batch_timer;  
  75.   batch_timer.Start();  
  76.   double read_time = 0;  
  77.   double trans_time = 0;  
  78.   CPUTimer timer;  
  79.   CHECK(batch->data_.count());  
  80.   CHECK(this->transformed_data_.count());  
  81.   
  82.   // Reshape according to the first datum of each batch  
  83.   // on single input batches allows for inputs of varying dimension.  
  84.   const int batch_size = this->layer_param_.data_param().batch_size();  
  85.   Datum& datum = *(reader_.full().peek());  
  86.   // Use data_transformer to infer the expected blob shape from datum.  
  87.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  88.   this->transformed_data_.Reshape(top_shape);  
  89.   // Reshape batch according to the batch_size.  
  90.   top_shape[0] = batch_size;  
  91.   batch->data_.Reshape(top_shape);  
  92.   
  93.   Dtype* top_data = batch->data_.mutable_cpu_data();  
  94.   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables  
  95.   
  96.   if (this->output_labels_) {  
  97.     top_label = batch->label_.mutable_cpu_data();  
  98.   }  
  99.   for (int item_id = 0; item_id < batch_size; ++item_id) {  
  100.     timer.Start();  
  101.     // get a datum  
  102.     Datum& datum = *(reader_.full().pop("Waiting for data"));  
  103.     read_time += timer.MicroSeconds();  
  104.     timer.Start();  
  105.     // Apply data transformations (mirror, scale, crop...)  
  106.     int offset = batch->data_.offset(item_id);  
  107.     this->transformed_data_.set_cpu_data(top_data + offset);  
  108.     this->data_transformer_->Transform(datum, &(this->transformed_data_));  
  109.   
  110.     // Copy label.  
  111.     //###  
  112.     // if (this->output_labels_) {  
  113.     //   top_label[item_id] = datum.label();  
  114.     // }  
  115.   
  116.     //###  
  117.     int labelNum = 4;  
  118.     if (this->output_labels_) {  
  119.       for(int i=0;i<labelNum;i++){  
  120.         top_label[item_id*labelNum+i] = datum.float_data(i); //read float labels  
  121.       }  
  122.     }  
  123.   
  124.   
  125.     trans_time += timer.MicroSeconds();  
  126.   
  127.     reader_.free().push(const_cast<Datum*>(&datum));  
  128.   }  
  129.   timer.Stop();  
  130.   batch_timer.Stop();  
  131.   DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  
  132.   DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";  
  133.   DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";  
  134. }  
  135.   
  136. INSTANTIATE_CLASS(DataLayer);  
  137. REGISTER_LAYER_CLASS(Data);  
  138.   
  139. }  // namespace caffe  


其中,第一处修改是:


[cpp]  view plain  copy
  1. //###  
  2. int labelNum = 4;<span style="white-space:pre;">    </span>//标签的数量,也就是txt中每一张图后面跟着的浮点数的数目  
  3. if (this->output_labels_) {  
  4.   
  5.   vector<int> label_shape;  
  6.   label_shape.push_back(batch_size);  
  7.   label_shape.push_back(labelNum);  
  8.   label_shape.push_back(1);  
  9.   label_shape.push_back(1);  
  10.   top[1]->Reshape(label_shape);  
  11.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  12.     this->prefetch_[i].label_.Reshape(label_shape);  
  13.   }  
  14. }  

从DataLayerSetup函数传进来的参数可以看到,top是一个向量的地址,而向量的元素是Blob<Dtype>*。因为在caffe网络结构中,图片信息是分成两个Blob进行传递的,一个Blob记录图片的像素值,另外一个Blob记录图片的标签,这里的top[0],top[1]分别与之对应(所以实际上我们要修改的是top[1]相关的内容,top[0]相关的我们并不需要管)。

上面的代码是对top[1]的Reshape,push_back的四个值分别对应Blob的num,channels,height,width。因为top[1]对应的是标签,所以num设置为batch_size,channels设置为labelNum,height和width设置为1即可。这一步相当于是“塑造”一个适合我们数据label的Blob出来。

第二处修改的地方是:


[cpp]  view plain  copy
  1. //###  
  2. int labelNum = 4;  
  3. if (this->output_labels_) {  
  4.   for(int i=0;i<labelNum;i++){  
  5.     top_label[item_id*labelNum+i] = datum.float_data(i); //read float labels  
  6.   }  

这个地方是将datum中的label值赋值给top_label。

完成了上面两处修改之后,跟上篇一样,需要回到caffe目录下,重新执行make编译一下data_layer.cpp。编译完成之后,我们的修改就生效了!这样一来,convert_imageset_regression完成了将回归数据制作成lmdb的任务,而data_layer则完成了将用于回归的lmdb成功送入后续网络的任务。

那么,要成功运行caffe.bin进行训练,还需要注意一下下面的细节,主要是要注意网络配置文件(.prototxt):

1、最后一个全连接层的num_output应该与labelNum(即label的数目相等)

2、做分类任务的时候,一般是使用SoftmaxWithLoss类型的loss层,而在做回归任务的时候,一般是用EuclideanLoss类型的loss层,因为loss主要体现在网络最后一个全连接层的输出与ground true的欧氏距离

3、不使用Accuracy层,因为回归任务没有所谓的准确率

4、如果要在数据层做crop,scale,mirror等操作,应该先考虑一下变换之后你的label是否也需要变化,不能像分类任务那么“直接”地用

5、修改data_layer.cpp并重新编译之后,下次如果要进行分类任务,得记得改回去并重新编译(或者可以在github上git clone多个caffe下来,这样就不用来回修改)。


完成了上面所有的工作之后就可以对自己的数据进行训练和测试了。训练之后得到caffemodel,就可以拿来应用了。应用的时候,可以用caffe的Python接口或者是继续修改源码。


BTW:如果有朋友觉得本文中直接在data_layer.cpp中声明变量:

[cpp]  view plain  copy
  1. int labelNum = 4;  
不美观的话,恭喜你,可以看看我的下面这篇博文:

在 caffe 中 “个性化定制” data_param参数


-----------------------------------------------------------------------------------------------------------------------------

在这里,要感谢超哥的指点,使我看caffe代码的时候容易了许多。下面是他的博客以及github:

his CSDN

his github

后来他说在github上写博客比较高大上,所以转移到了这里: his github blog ,反正我现在还没到这种境界~~


发现了一篇用caffe做多标签分类的博文,改代码的思路很相似,可以互相借鉴:

http://blog.csdn.net/hubin232/article/details/50960201



[cpp]  view plain  copy
  1. // optional int32 label = 5;  
  2. inline bool has_label() const;  
  3. inline void clear_label();  
  4. static const int kLabelFieldNumber = 5;  
  5. inline ::google::protobuf::int32 label() const;  
  6. inline void set_label(::google::protobuf::int32 value);  
  7.   
  8. // repeated float float_data = 6;  
  9. inline int float_data_size() const;  
  10. inline void clear_float_data();  
  11. static const int kFloatDataFieldNumber = 6;  
  12. inline float float_data(int index) const;  
  13. inline void set_float_data(int index, float value);  
  14. inline void add_float_data(float value);  
  15. inline const ::google::protobuf::RepeatedField< float >&  
  16.     float_data() const;  
  17. inline ::google::protobuf::RepeatedField< float >*  
  18.     mutable_float_data();  

在这里就可以看到,关于操作label的一系列函数,如果我们不使用add_float_data,而是用set_float_data,也是可以的!

上篇就到这里吧。















评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值