Caffe study notes (2): Multi-label classification/regression training (Part 1)

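Multi-label classification has many practical applications. For example, given a photo of a person, we may want to predict their age, their gender, and whether they wear glasses. In that case the dataset's label file should look like this: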

  • 000001.jpg 22 1 0
  • 000002.jpg 30 1 1
  • 000003.jpg 44 0 1
  • 000004.jpg 17 0 0

Here the first number is the age, the second is 0/1 for female/male, and the third is 0/1 for no glasses/glasses.
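To make the convention concrete, here is a minimal sketch of how the label part of one such line could be interpreted. The struct FaceLabels and the function DecodeFaceLabels are hypothetical names introduced for illustration and are not part of Caffe; the field order simply follows the convention stated above.

```cpp
#include <cassert>    // assert, used in the usage check below
#include <stdexcept>
#include <vector>

// Hypothetical container for the three attributes of one image
// (not part of Caffe; for illustration only).
struct FaceLabels {
  float age;     // first label, e.g. 22
  bool male;     // second label: 0 = female, 1 = male
  bool glasses;  // third label: 0 = no glasses, 1 = wears glasses
};

// Interprets the labels of one line such as "000001.jpg 22 1 0"
// (with the file name already stripped off).
FaceLabels DecodeFaceLabels(const std::vector<float>& labels) {
  if (labels.size() != 3) {
    throw std::invalid_argument("expected exactly 3 labels per image");
  }
  FaceLabels out;
  out.age = labels[0];
  out.male = labels[1] > 0.5f;     // labels arrive as floats, so threshold at 0.5
  out.glasses = labels[2] > 0.5f;
  return out;
}
```

For example, DecodeFaceLabels({22, 1, 0}) describes a 22-year-old man without glasses.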

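Likewise, regression has many uses in CNNs. In object detection (e.g. Faster R-CNN, SSD, YOLO), for instance, the network takes an image and outputs translation and scaling parameters for a bounding box (in practice it is a bit more complicated than that). In that case the label file should look like this: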

  • 000001.jpg -0.01 -0.29 -0.05 0.02
  • 000002.jpg 0.25 -0.04 0.02 -0.07
  • 000003.jpg 0.23 0.05 -0.22 0.15
  • 000004.jpg 0.20 -0.12 0.02 -0.20

The four numbers are the offsets of x, y, w and h respectively.

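(The sample lines above are actually taken from the bounding-box regression part of the training data for the MTCNN face detector; the MTCNN training data format will be covered in the next post.)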
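The post does not spell out how the four numbers are applied to a box. A common convention, assumed here and not confirmed by the source (MTCNN and Faster R-CNN each define their own parameterization), normalizes the x/y shifts by the box size and treats the w/h terms as relative size changes. A minimal sketch under that assumption; Box and ApplyOffsets are hypothetical names:

```cpp
#include <cassert>  // assert, used in the usage check below
#include <cmath>    // std::fabs, used in the usage check below

// Hypothetical axis-aligned box: top-left corner (x, y), width w, height h.
struct Box {
  float x, y, w, h;
};

// Applies regression offsets (dx, dy, dw, dh) to a box, ASSUMING the
// convention x' = x + dx*w, y' = y + dy*h, w' = w*(1 + dw), h' = h*(1 + dh).
// Other networks use different formulas (e.g. exp(dw) for the size terms).
Box ApplyOffsets(const Box& b, float dx, float dy, float dw, float dh) {
  Box out;
  out.x = b.x + dx * b.w;
  out.y = b.y + dy * b.h;
  out.w = b.w * (1.0f + dw);
  out.h = b.h * (1.0f + dh);
  return out;
}
```

Under this assumption, a 50x50 box at (100, 100) with offsets (0.1, -0.2, 0, 0.5) moves to (105, 90) and becomes 50x75.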

--------------------------------------------------------------------------------

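Caffe supports single-label classification well, but its support for multi-label classification and regression is poor. The executable build/tools/convert_imageset, which generates LMDB datasets, only accepts a single integer label per image. Multi-label classification and regression clearly need a program that accepts multiple labels and floating-point values, and that is what we will build here.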
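To see concretely why the stock tool cannot read these files, the sketch below reproduces the original parsing logic of convert_imageset (the lines that are commented out in the modified version further down: find_last_of(' ') followed by atoi). ParseSingleLabel is a hypothetical wrapper name.

```cpp
#include <cassert>  // assert, used in the usage check below
#include <cstdlib>
#include <string>
#include <utility>

// Mirrors the stock convert_imageset parsing: everything before the LAST
// space is taken as the file name, and the last token is converted with atoi.
std::pair<std::string, int> ParseSingleLabel(const std::string& line) {
  const size_t pos = line.find_last_of(' ');
  const int label = std::atoi(line.substr(pos + 1).c_str());
  return std::make_pair(line.substr(0, pos), label);
}
```

For "000001.jpg 22 1 0" the "file name" becomes "000001.jpg 22 1" and the label 0; for a regression line the last offset "0.02" is truncated to the integer 0.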

Note that reading classification labels as floating-point numbers does not affect the classification results.
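One reason is that stock Caffe already converts the integer label to Dtype (float by default) when it fills the label blob, and a float represents every integer up to 2^24 = 16777216 in magnitude exactly, so class indices survive the int → float → int round trip unchanged. A quick check (RoundTripsExactly is a hypothetical helper):

```cpp
#include <cassert>  // assert, used in the usage check below

// Returns true if an integer class label survives conversion to float and back.
// IEEE-754 single precision has a 24-bit significand, so this holds for
// all |label| <= 16777216 (2^24).
bool RoundTripsExactly(int label) {
  const float stored = static_cast<float>(label);
  return static_cast<int>(stored) == label;
}
```

Class counts in practice are far below 2^24, so the conversion is lossless for them.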

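I wanted to keep using LMDB datasets, which requires modifying some of the Caffe source. The changes are somewhat involved, but since they only have to be made once, they are worth the effort.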

 

Reference blog post: 用 caffe 做回归 (上) ("Regression with Caffe, Part 1")

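That author's post solves most of the problem. I have collected all of his changes in one place; in total, the following files need to be modified:

  • tools/convert_imageset.cpp: generates the LMDB dataset. Make a copy of this file named tools/convert_imageset_multi.cpp and modify the copy.
  • include/caffe/util/io.hpp: declares a new function used by the program above
  • src/caffe/util/io.cpp: implements that function
  • src/caffe/layers/data_layer.cpp: loads the LMDB dataset
  • src/caffe/proto/caffe.proto: adds the label_num parameter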

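Note that I used the BVLC version of Caffe (as of July 2018): https://github.com/BVLC/caffe

If your Caffe version differs, these files probably cannot be used as-is. The modified versions follow; //### marks the changed parts. The main idea is to replace the int label with a vector<float>.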

1. tools/convert_imageset_multi.cpp:

// This program converts a set of images to a lmdb/leveldb by storing them
// as Datum proto buffers.
// Usage:
//   convert_imageset_multi [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
//
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// should be a list of files as well as their (one or more, float) labels,
// in the format as
//   subfolder1/file1.JPEG 22 1 0
//   ....

#include <algorithm>
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>

#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"

#include <cstdlib>              //### atof
#include <iostream>             //###
#include <iterator>             //### std::back_inserter
#include <boost/tokenizer.hpp>  //###

using namespace caffe;  // NOLINT(build/namespaces)
using std::pair;
using boost::scoped_ptr;

DEFINE_bool(gray, false,
    "When this option is on, treat images as grayscale ones");
DEFINE_bool(shuffle, false,
    "Randomly shuffle the order of images and their labels");
DEFINE_string(backend, "lmdb",
        "The backend {lmdb, leveldb} for storing the result");
DEFINE_int32(resize_width, 0, "Width images are resized to");
DEFINE_int32(resize_height, 0, "Height images are resized to");
DEFINE_bool(check_size, false,
    "When this option is on, check that all the datum have the same size");
DEFINE_bool(encoded, false,
    "When this option is on, the encoded image will be save in datum");
DEFINE_string(encode_type, "",
    "Optional: What type should we encode the image as ('png','jpg',...).");

int main(int argc, char** argv) {
#ifdef USE_OPENCV
  ::google::InitGoogleLogging(argv[0]);
  // Print output to stderr (while still logging)
  FLAGS_alsologtostderr = 1;

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;
  const string encode_type = FLAGS_encode_type;

  std::ifstream infile(argv[2]);
  // std::vector<std::pair<std::string, int> > lines;
  std::vector<std::pair<std::string, std::vector<float> > > lines;  //### int -> vector<float>
  std::string line;
  // size_t pos;              //### no longer needed
  // int label;               //###
  std::vector<float> labels;  //###
  while (std::getline(infile, line)) {
    //###
    // pos = line.find_last_of(' ');
    // label = atoi(line.substr(pos + 1).c_str());
    // lines.push_back(std::make_pair(line.substr(0, pos), label));

    //### Split on spaces/tabs (and a trailing '\r' from Windows line endings):
    // the first token is the image path, all remaining tokens are float labels.
    std::vector<std::string> tokens;
    boost::char_separator<char> sep(" \t\r");
    boost::tokenizer<boost::char_separator<char> > tok(line, sep);
    std::copy(tok.begin(), tok.end(), std::back_inserter(tokens));
    if (tokens.empty()) continue;  // skip blank lines instead of throwing in tokens.at(0)

    for (size_t i = 1; i < tokens.size(); ++i)
      labels.push_back(atof(tokens.at(i).c_str()));
    lines.push_back(std::make_pair(tokens.at(0), labels));
    labels.clear();
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encode_type.size() && !encoded)
    LOG(INFO) << "encode_type specified, assuming encoded=true.";

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  int data_size = 0;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    bool status;

    std::string enc = encode_type;
    if (encoded && !enc.size()) {
      // Guess the encoding type from the file name
      string fn = lines[line_id].first;
      size_t p = fn.rfind('.');
      if ( p == fn.npos )
        LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
      enc = fn.substr(p+1);
      std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
    }
    status = ReadImageToDatum(root_folder + lines[line_id].first,  //### unchanged, but now calls the new vector<float> overload
        lines[line_id].second, resize_height, resize_width, is_color,
        enc, &datum);
    if (status == false) continue;
    if (check_size) {
      if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
      } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
    }
    // sequential
    string key_str = caffe::format_int(line_id, 8) + "_" + lines[line_id].first;

    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(key_str, out);

    if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(INFO) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(INFO) << "Processed " << count << " files.";
  }
#else
  LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  return 0;
}
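To sanity-check a label file before building an LMDB, the tokenizing step of the program above can be reproduced without Boost. This is a sketch, not the author's code: ParseMultiLabelLine is a hypothetical name, and it uses std::istringstream (which splits on any whitespace) instead of boost::tokenizer.

```cpp
#include <cassert>  // assert, used in the usage check below
#include <cstdlib>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Splits "path label1 label2 ..." on whitespace (spaces, tabs, a trailing '\r'):
// the first token is the image path, every following token is parsed as a
// float label. Returns false for blank lines.
bool ParseMultiLabelLine(const std::string& line,
                         std::pair<std::string, std::vector<float> >* out) {
  std::istringstream iss(line);
  std::string token;
  if (!(iss >> token)) return false;  // blank or whitespace-only line
  out->first = token;
  out->second.clear();
  while (iss >> token) {
    out->second.push_back(static_cast<float>(std::atof(token.c_str())));
  }
  return true;
}
```

Running every line of the list file through such a function makes it easy to spot malformed lines before convert_imageset_multi is invoked.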

2. include/caffe/util/io.hpp (excerpt):

Find the ReadImageToDatum function and add an overload taking vector<float> right below it (do not delete the original function).

bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum);

bool ReadImageToDatum(const string& filename, const vector<float> labels,  //###
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum);

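3. src/caffe/util/io.cpp (excerpt):

Add the implementation of the new overload. The commented-out encoding branch handles encoded image storage; it relies on the single-label set_label, so it is disabled here. Since encode_type defaults to an empty string, that branch is skipped by default anyway. Be aware, however, that with this overload the --encoded and --encode_type flags of convert_imageset_multi are silently ignored and images are always stored as raw pixels.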

bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum) {
  cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
  if (cv_img.data) {
    if (encoding.size()) {
      if ( (cv_img.channels() == 3) == is_color && !height && !width &&
          matchExt(filename, encoding) )
        return ReadFileToDatum(filename, label, datum);
      std::vector<uchar> buf;
      cv::imencode("."+encoding, cv_img, buf);
      datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
                      buf.size()));
      datum->set_label(label);
      datum->set_encoded(true);
      return true;
    }
    CVMatToDatum(cv_img, datum);
    datum->set_label(label);
    return true;
  } else {
    return false;
  }
}

bool ReadImageToDatum(const string& filename, const vector<float> labels,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum) {
  cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
  if (cv_img.data) {
    // if (encoding.size()) {
    //   if ( (cv_img.channels() == 3) == is_color && !height && !width &&
    //       matchExt(filename, encoding) )
    //     return ReadFileToDatum(filename, label, datum);
    //   std::vector<uchar> buf;
    //   cv::imencode("."+encoding, cv_img, buf);
    //   datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
    //                   buf.size()));
    //   datum->set_label(label);
    //   datum->set_encoded(true);
    //   return true;
    // }
    CVMatToDatum(cv_img, datum);
    // datum->set_label(label);

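    //### Store every label in the repeated float_data field; the image
    // pixels stay in the data field filled by CVMatToDatum above.
    for (size_t i = 0; i < labels.size(); ++i)
      datum->add_float_data(labels.at(i));
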
    return true;
  } else {
    return false;
  }
}

4. src/caffe/layers/data_layer.cpp:

#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif  // USE_OPENCV
#include <stdint.h>

#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/data_layer.hpp"
#include "caffe/util/benchmark.hpp"

namespace caffe {

template <typename Dtype>
DataLayer<Dtype>::DataLayer(const LayerParameter& param)
  : BasePrefetchingDataLayer<Dtype>(param),
    offset_() {
  db_.reset(db::GetDB(param.data_param().backend()));
  db_->Open(param.data_param().source(), db::READ);
  cursor_.reset(db_->NewCursor());
}

template <typename Dtype>
DataLayer<Dtype>::~DataLayer() {
  this->StopInternalThread();
}

template <typename Dtype>
void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int batch_size = this->layer_param_.data_param().batch_size();
  // Read a data point, and use it to initialize the top blob.
  Datum datum;
  datum.ParseFromString(cursor_->value());

  // Use data_transformer to infer the expected blob shape from datum.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
  this->transformed_data_.Reshape(top_shape);
  // Reshape top[0] and prefetch_data according to the batch_size.
  top_shape[0] = batch_size;
  top[0]->Reshape(top_shape);
  for (int i = 0; i < this->prefetch_.size(); ++i) {
    this->prefetch_[i]->data_.Reshape(top_shape);
  }
  LOG_IF(INFO, Caffe::root_solver())
      << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();

  //###
  // label
  // if (this->output_labels_) {
  //   vector<int> label_shape(1, batch_size);
  //   top[1]->Reshape(label_shape);
  //   for (int i = 0; i < this->prefetch_.size(); ++i) {
  //     this->prefetch_[i]->label_.Reshape(label_shape);
  //   }
  // }

  //###
  const int label_num = this->layer_param_.data_param().label_num();
  if (this->output_labels_) {
    vector<int> label_shape;
    label_shape.push_back(batch_size);
    label_shape.push_back(label_num);
    label_shape.push_back(1);
    label_shape.push_back(1);
    top[1]->Reshape(label_shape);
    for (int i = 0; i < this->prefetch_.size(); ++i) {
      this->prefetch_[i]->label_.Reshape(label_shape);
    }
  }
}

template <typename Dtype>
bool DataLayer<Dtype>::Skip() {
  int size = Caffe::solver_count();
  int rank = Caffe::solver_rank();
  bool keep = (offset_ % size) == rank ||
              // In test mode, only rank 0 runs, so avoid skipping
              this->layer_param_.phase() == TEST;
  return !keep;
}

template<typename Dtype>
void DataLayer<Dtype>::Next() {
  cursor_->Next();
  if (!cursor_->valid()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Restarting data prefetching from start.";
    cursor_->SeekToFirst();
  }
  offset_++;
}

// This function is called on prefetch thread
template<typename Dtype>
void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());
  const int batch_size = this->layer_param_.data_param().batch_size();

  Datum datum;
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    timer.Start();
    while (Skip()) {
      Next();
    }
    datum.ParseFromString(cursor_->value());
    read_time += timer.MicroSeconds();

    if (item_id == 0) {
      // Reshape according to the first datum of each batch
      // on single input batches allows for inputs of varying dimension.
      // Use data_transformer to infer the expected blob shape from datum.
      vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
      this->transformed_data_.Reshape(top_shape);
      // Reshape batch according to the batch_size.
      top_shape[0] = batch_size;
      batch->data_.Reshape(top_shape);
    }

    // Apply data transformations (mirror, scale, crop...)
    timer.Start();
    int offset = batch->data_.offset(item_id);
    Dtype* top_data = batch->data_.mutable_cpu_data();
    this->transformed_data_.set_cpu_data(top_data + offset);
    this->data_transformer_->Transform(datum, &(this->transformed_data_));

    //###
    // Copy label.
    // if (this->output_labels_) {
    //   Dtype* top_label = batch->label_.mutable_cpu_data();
    //   top_label[item_id] = datum.label();
    // }

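    //### Copy the label_num float labels of this sample into row item_id.
    const int label_num = this->layer_param_.data_param().label_num();
    if (this->output_labels_) {
      CHECK_GE(datum.float_data_size(), label_num)
          << "Datum has fewer float labels than data_param.label_num";
      Dtype* top_label = batch->label_.mutable_cpu_data();
      for (int i = 0; i < label_num; ++i) {
        top_label[item_id * label_num + i] = datum.float_data(i);  // read float labels
      }
    }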

    trans_time += timer.MicroSeconds();
    Next();
  }
  timer.Stop();
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(DataLayer);
REGISTER_LAYER_CLASS(Data);

}  // namespace caffe
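The label blob is shaped (batch_size, label_num, 1, 1), so the loop in load_batch writes label i of sample item_id to flat index item_id * label_num + i, which is the blob offset of element (item_id, i, 0, 0) for that shape. The same packing in isolation (PackLabels is a hypothetical helper, not a Caffe function):

```cpp
#include <cassert>  // assert, used in the usage check below
#include <stdexcept>
#include <vector>

// Packs per-sample label vectors into one row-major buffer laid out like a
// (batch_size, label_num, 1, 1) blob: label i of sample n lands at n * label_num + i.
// Labels beyond label_num are ignored, as in load_batch.
std::vector<float> PackLabels(const std::vector<std::vector<float> >& samples,
                              int label_num) {
  std::vector<float> flat(samples.size() * label_num);
  for (size_t n = 0; n < samples.size(); ++n) {
    if (static_cast<int>(samples[n].size()) < label_num) {
      throw std::runtime_error("sample has fewer labels than label_num");
    }
    for (int i = 0; i < label_num; ++i) {
      flat[n * label_num + i] = samples[n][i];
    }
  }
  return flat;
}
```

With the two samples {22, 1, 0} and {30, 1, 1} and label_num = 3, the buffer is 22 1 0 30 1 1.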

5. src/caffe/proto/caffe.proto (excerpt):

Add a label_num field to DataParameter (11 is the next unused field number in this version).

message DataParameter {
  enum DB {
    LEVELDB = 0;
    LMDB = 1;
  }
  // Specify the data source.
  optional string source = 1;
  // Specify the batch size.
  optional uint32 batch_size = 4;
  // The rand_skip variable is for the data layer to skip a few data points
  // to avoid all asynchronous sgd clients to start at the same point. The skip
  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
  // be larger than the number of keys in the database.
  // DEPRECATED. Each solver accesses a different subset of the database.
  optional uint32 rand_skip = 7 [default = 0];
  optional DB backend = 8 [default = LEVELDB];
  // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
  // simple scaling and subtracting the data mean, if provided. Note that the
  // mean subtraction is always carried out before scaling.
  optional float scale = 2 [default = 1];
  optional string mean_file = 3;
  // DEPRECATED. See TransformationParameter. Specify if we would like to randomly
  // crop an image.
  optional uint32 crop_size = 5 [default = 0];
  // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
  // data.
  optional bool mirror = 6 [default = false];
  // Force the encoded image to have 3 color channels
  optional bool force_encoded_color = 9 [default = false];
  // Prefetch queue (Increase if data feeding bandwidth varies, within the
  // limit of device memory for GPU training)
  optional uint32 prefetch = 10 [default = 4];
  //### For multi-task training
  optional uint32 label_num = 11 [default = 1];
}
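label_num must match the number of labels on each line of the list file; if it is larger, the float_data(i) reads in load_batch go past the stored labels. A hypothetical pre-flight check (LabelCountsMatch is not part of Caffe) that could be run over the list file before training, requiring an exact match to catch mistakes early:

```cpp
#include <cassert>  // assert, used in the usage check below
#include <sstream>
#include <string>
#include <vector>

// Returns true if every non-blank line of a list file has exactly
// label_num labels after the image path.
bool LabelCountsMatch(const std::vector<std::string>& lines, int label_num) {
  for (size_t k = 0; k < lines.size(); ++k) {
    std::istringstream iss(lines[k]);
    std::string token;
    if (!(iss >> token)) continue;  // skip blank lines
    int count = 0;
    while (iss >> token) ++count;   // count the tokens after the path
    if (count != label_num) return false;
  }
  return true;
}
```

For the face-attribute file at the top of this post, label_num would be 3; for the bounding-box file, 4.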

 

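That completes the code changes. Rebuild Caffe and a new executable, build/tools/convert_imageset_multi, will be generated; it can read multi-label files like the ones shown at the beginning. Keep in mind that the modified Data layer now reads labels only from float_data, so LMDBs created with the stock convert_imageset (which store the label in the label field) no longer provide usable labels.

How to use it will be covered in the next post.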
