Setting Up the MTCNN Training Environment (Part 1)

Original post: https://blog.csdn.net/wwww1244/article/details/81034045

Multi-label classification has many practical applications. For example, given an input image of a person, predict their age, gender, and whether they wear glasses. In that case, the dataset's label file should have this format:

  • 000001.jpg 22 1 0
  • 000002.jpg 30 1 1
  • 000003.jpg 44 0 1
  • 000004.jpg 17 0 0

Assume the first number is the age, the second 0/1 means female/male, and the third 0/1 means no glasses/glasses.

Similarly, regression problems are common in CNNs. For example, in object detection (networks such as Faster R-CNN, SSD, and YOLO), the input is an image and the output is the set of translation and scaling parameters of a bounding box (the real situation is a bit more complicated, of course). In that case, the label file should have this format:

  • 000001.jpg -0.01 -0.29 -0.05 0.02
  • 000002.jpg 0.25 -0.04 0.02 -0.07
  • 000003.jpg 0.23 0.05 -0.22 0.15
  • 000004.jpg 0.20 -0.12 0.02 -0.20

The four numbers are the offsets of x, y, w, and h, respectively.

(In fact, the lines above are taken from the bounding-box regression portion of the MTCNN face-detection training data; the MTCNN training data format will be covered in the next post.)

--------------------------------------------------------------------------------

Caffe supports single-label classification well, but its support for multi-label classification and regression is poor: the executable build/tools/convert_imageset used to generate lmdb datasets only accepts a single integer label. Multi-label classification and regression obviously require the program to accept multiple floating-point labels, and that is the work we need to do.

Note that reading classification labels as floating-point numbers does not affect the classification results.

I want to keep using lmdb datasets, which requires modifying some Caffe source code. The changes are a bit involved, but they are essentially a one-time effort, so they are worth making.

 

Reference: the blog post 用 caffe 做回归 (上) ("Doing regression with Caffe, part 1").

That post solves most of the problem; I have consolidated the author's changes here. In total, the following files need to be modified:

  • tools/convert_imageset.cpp: generates the lmdb dataset. Copy this file, rename the copy to tools/convert_imageset_multi.cpp, and edit the copy.
  • include/caffe/util/io.hpp: declares the new function used by the tool above
  • src/caffe/util/io.cpp: implements that function
  • src/caffe/layers/data_layer.cpp: loads the lmdb dataset
  • src/caffe/proto/caffe.proto: adds the label_num parameter

First, a note: I am using the BVLC version of Caffe (as of July 2018): https://github.com/BVLC/caffe

If your Caffe version differs, these files will likely not work as-is. The modified files follow; //### marks the changed parts. The main idea is to replace the int label with a vector<float>.

1. tools/convert_imageset_multi.cpp:

// This program converts a set of images to a lmdb/leveldb by storing them
// as Datum proto buffers.
// Usage:
//   convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
//
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// should be a list of files as well as their labels, in the format as
//   subfolder1/file1.JPEG 7
//   ....

#include <algorithm>
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>

#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"

#include <iostream>             //###
#include <boost/tokenizer.hpp>  //###

using namespace caffe;  // NOLINT(build/namespaces)
using std::pair;
using boost::scoped_ptr;

DEFINE_bool(gray, false,
    "When this option is on, treat images as grayscale ones");
DEFINE_bool(shuffle, false,
    "Randomly shuffle the order of images and their labels");
DEFINE_string(backend, "lmdb",
    "The backend {lmdb, leveldb} for storing the result");
DEFINE_int32(resize_width, 0, "Width images are resized to");
DEFINE_int32(resize_height, 0, "Height images are resized to");
DEFINE_bool(check_size, false,
    "When this option is on, check that all the datum have the same size");
DEFINE_bool(encoded, false,
    "When this option is on, the encoded image will be saved in datum");
DEFINE_string(encode_type, "",
    "Optional: What type should we encode the image as ('png','jpg',...).");

int main(int argc, char** argv) {
#ifdef USE_OPENCV
  ::google::InitGoogleLogging(argv[0]);
  // Print output to stderr (while still logging)
  FLAGS_alsologtostderr = 1;

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;
  const string encode_type = FLAGS_encode_type;

  std::ifstream infile(argv[2]);
  // std::vector<std::pair<std::string, int> > lines;
  std::vector<std::pair<std::string, std::vector<float> > > lines;  //### int -> vector<float>
  std::string line;
  size_t pos;
  // int label; //###
  std::vector<float> labels;  //###
  while (std::getline(infile, line)) {
    //###
    // pos = line.find_last_of(' ');
    // label = atoi(line.substr(pos + 1).c_str());
    // lines.push_back(std::make_pair(line.substr(0, pos), label));
    //###
    std::vector<std::string> tokens;
    boost::char_separator<char> sep(" ");
    boost::tokenizer<boost::char_separator<char> > tok(line, sep);
    tokens.clear();
    std::copy(tok.begin(), tok.end(), std::back_inserter(tokens));
    for (int i = 1; i < tokens.size(); ++i)
      labels.push_back(atof(tokens.at(i).c_str()));
    lines.push_back(std::make_pair(tokens.at(0), labels));
    labels.clear();
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encode_type.size() && !encoded)
    LOG(INFO) << "encode_type specified, assuming encoded=true.";

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  int data_size = 0;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    bool status;
    std::string enc = encode_type;
    if (encoded && !enc.size()) {
      // Guess the encoding type from the file name
      string fn = lines[line_id].first;
      size_t p = fn.rfind('.');
      if (p == fn.npos)
        LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
      enc = fn.substr(p+1);
      std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
    }
    status = ReadImageToDatum(root_folder + lines[line_id].first,  //### unchanged, but now resolves to the new overload
        lines[line_id].second, resize_height, resize_width, is_color,
        enc, &datum);
    if (status == false) continue;
    if (check_size) {
      if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
      } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
    }
    // sequential
    string key_str = caffe::format_int(line_id, 8) + "_" + lines[line_id].first;

    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(key_str, out);

    if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(INFO) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(INFO) << "Processed " << count << " files.";
  }
#else
  LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  return 0;
}

2. include/caffe/util/io.hpp (excerpt):

Find the ReadImageToDatum function and add, below it (do not delete the original), an overload that takes a vector<float>.

bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum);

bool ReadImageToDatum(const string& filename, const vector<float> labels,  //###
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum);

3. src/caffe/util/io.cpp (excerpt):

Add the implementation of the overload declared above. The commented-out encoding branch deals with the image encoding format; encoding defaults to the empty string, so this branch can simply be commented out.

bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum) {
  cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
  if (cv_img.data) {
    if (encoding.size()) {
      if ( (cv_img.channels() == 3) == is_color && !height && !width &&
          matchExt(filename, encoding) )
        return ReadFileToDatum(filename, label, datum);
      std::vector<uchar> buf;
      cv::imencode("."+encoding, cv_img, buf);
      datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
                      buf.size()));
      datum->set_label(label);
      datum->set_encoded(true);
      return true;
    }
    CVMatToDatum(cv_img, datum);
    datum->set_label(label);
    return true;
  } else {
    return false;
  }
}

bool ReadImageToDatum(const string& filename, const vector<float> labels,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum) {
  cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
  if (cv_img.data) {
    // if (encoding.size()) {
    //   if ( (cv_img.channels() == 3) == is_color && !height && !width &&
    //       matchExt(filename, encoding) )
    //     return ReadFileToDatum(filename, label, datum);
    //   std::vector<uchar> buf;
    //   cv::imencode("."+encoding, cv_img, buf);
    //   datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
    //                   buf.size()));
    //   datum->set_label(label);
    //   datum->set_encoded(true);
    //   return true;
    // }
    CVMatToDatum(cv_img, datum);
    // datum->set_label(label);
    //###
    for (int i = 0; i < labels.size(); ++i)
      datum->add_float_data(labels.at(i));
    return true;
  } else {
    return false;
  }
}

4. src/caffe/layers/data_layer.cpp:

#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif  // USE_OPENCV
#include <stdint.h>

#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/data_layer.hpp"
#include "caffe/util/benchmark.hpp"

namespace caffe {

template <typename Dtype>
DataLayer<Dtype>::DataLayer(const LayerParameter& param)
  : BasePrefetchingDataLayer<Dtype>(param),
    offset_() {
  db_.reset(db::GetDB(param.data_param().backend()));
  db_->Open(param.data_param().source(), db::READ);
  cursor_.reset(db_->NewCursor());
}

template <typename Dtype>
DataLayer<Dtype>::~DataLayer() {
  this->StopInternalThread();
}

template <typename Dtype>
void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int batch_size = this->layer_param_.data_param().batch_size();
  // Read a data point, and use it to initialize the top blob.
  Datum datum;
  datum.ParseFromString(cursor_->value());

  // Use data_transformer to infer the expected blob shape from datum.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
  this->transformed_data_.Reshape(top_shape);
  // Reshape top[0] and prefetch_data according to the batch_size.
  top_shape[0] = batch_size;
  top[0]->Reshape(top_shape);
  for (int i = 0; i < this->prefetch_.size(); ++i) {
    this->prefetch_[i]->data_.Reshape(top_shape);
  }
  LOG_IF(INFO, Caffe::root_solver())
      << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  //###
  // label
  // if (this->output_labels_) {
  //   vector<int> label_shape(1, batch_size);
  //   top[1]->Reshape(label_shape);
  //   for (int i = 0; i < this->prefetch_.size(); ++i) {
  //     this->prefetch_[i]->label_.Reshape(label_shape);
  //   }
  // }
  //###
  const int label_num = this->layer_param_.data_param().label_num();
  if (this->output_labels_) {
    vector<int> label_shape;
    label_shape.push_back(batch_size);
    label_shape.push_back(label_num);
    label_shape.push_back(1);
    label_shape.push_back(1);
    top[1]->Reshape(label_shape);
    for (int i = 0; i < this->prefetch_.size(); ++i) {
      this->prefetch_[i]->label_.Reshape(label_shape);
    }
  }
}

template <typename Dtype>
bool DataLayer<Dtype>::Skip() {
  int size = Caffe::solver_count();
  int rank = Caffe::solver_rank();
  bool keep = (offset_ % size) == rank ||
              // In test mode, only rank 0 runs, so avoid skipping
              this->layer_param_.phase() == TEST;
  return !keep;
}

template<typename Dtype>
void DataLayer<Dtype>::Next() {
  cursor_->Next();
  if (!cursor_->valid()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Restarting data prefetching from start.";
    cursor_->SeekToFirst();
  }
  offset_++;
}

// This function is called on prefetch thread
template<typename Dtype>
void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());
  const int batch_size = this->layer_param_.data_param().batch_size();

  Datum datum;
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    timer.Start();
    while (Skip()) {
      Next();
    }
    datum.ParseFromString(cursor_->value());
    read_time += timer.MicroSeconds();

    if (item_id == 0) {
      // Reshape according to the first datum of each batch
      // on single input batches allows for inputs of varying dimension.
      // Use data_transformer to infer the expected blob shape from datum.
      vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
      this->transformed_data_.Reshape(top_shape);
      // Reshape batch according to the batch_size.
      top_shape[0] = batch_size;
      batch->data_.Reshape(top_shape);
    }

    // Apply data transformations (mirror, scale, crop...)
    timer.Start();
    int offset = batch->data_.offset(item_id);
    Dtype* top_data = batch->data_.mutable_cpu_data();
    this->transformed_data_.set_cpu_data(top_data + offset);
    this->data_transformer_->Transform(datum, &(this->transformed_data_));
    //###
    // Copy label.
    // if (this->output_labels_) {
    //   Dtype* top_label = batch->label_.mutable_cpu_data();
    //   top_label[item_id] = datum.label();
    // }
    //###
    const int label_num = this->layer_param_.data_param().label_num();
    if (this->output_labels_) {
      Dtype* top_label = batch->label_.mutable_cpu_data();
      for (int i = 0; i < label_num; i++) {
        top_label[item_id * label_num + i] = datum.float_data(i);  // read float labels
      }
    }
    trans_time += timer.MicroSeconds();
    Next();
  }
  timer.Stop();
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(DataLayer);
REGISTER_LAYER_CLASS(Data);

}  // namespace caffe

5. src/caffe/proto/caffe.proto (excerpt):

Add the label_num field to DataParameter.

message DataParameter {
  enum DB {
    LEVELDB = 0;
    LMDB = 1;
  }
  // Specify the data source.
  optional string source = 1;
  // Specify the batch size.
  optional uint32 batch_size = 4;
  // The rand_skip variable is for the data layer to skip a few data points
  // to avoid all asynchronous sgd clients to start at the same point. The skip
  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
  // be larger than the number of keys in the database.
  // DEPRECATED. Each solver accesses a different subset of the database.
  optional uint32 rand_skip = 7 [default = 0];
  optional DB backend = 8 [default = LEVELDB];
  // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
  // simple scaling and subtracting the data mean, if provided. Note that the
  // mean subtraction is always carried out before scaling.
  optional float scale = 2 [default = 1];
  optional string mean_file = 3;
  // DEPRECATED. See TransformationParameter. Specify if we would like to randomly
  // crop an image.
  optional uint32 crop_size = 5 [default = 0];
  // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
  // data.
  optional bool mirror = 6 [default = false];
  // Force the encoded image to have 3 color channels
  optional bool force_encoded_color = 9 [default = false];
  // Prefetch queue (Increase if data feeding bandwidth varies, within the
  // limit of device memory for GPU training)
  optional uint32 prefetch = 10 [default = 4];
  //### For multi-task training
  optional uint32 label_num = 11 [default = 1];
}
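For context, here is a sketch of how the new parameter might later appear in a Data layer definition in a train prototxt. The layer name, source path, and batch size are hypothetical placeholders, not taken from this post:

```
layer {
  name: "data"
  type: "Data"
  top: "data"
  top: "label"    # shape: batch_size x label_num x 1 x 1
  data_param {
    source: "examples/mtcnn/train_lmdb"   # hypothetical path
    batch_size: 64
    backend: LMDB
    label_num: 4  # number of float labels per image
  }
}
```

If different losses consume different parts of the label (e.g. classification vs. bounding-box regression), a Slice layer can split the label blob along its second axis.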

 

That completes the code changes. Recompile Caffe and a new executable, build/tools/convert_imageset_multi, will be generated; it can read multi-label files in the format described at the beginning.

How to invoke it will be covered in the next post.
