Multi-label classification has plenty of engineering applications. For example, given a picture of a person, predict their age, gender, and whether they are wearing glasses. In that case, the dataset's label file should have this format:
- 000001.jpg 22 1 0
- 000002.jpg 30 1 1
- 000003.jpg 44 0 1
- 000004.jpg 17 0 0
Assume the first number is the age, the second 0/1 means female/male, and the third 0/1 means not wearing/wearing glasses.
Likewise, regression problems have many applications in CNNs. For example, in object detection (networks such as Faster R-CNN, SSD, and YOLO), the input is an image and the output is a set of translation/scaling parameters for a bounding box (the real situation is a bit more involved, of course). In that case, the label file should have this format:
- 000001.jpg -0.01 -0.29 -0.05 0.02
- 000002.jpg 0.25 -0.04 0.02 -0.07
- 000003.jpg 0.23 0.05 -0.22 0.15
- 000004.jpg 0.20 -0.12 0.02 -0.20
The four numbers are the changes in x, y, w, and h, respectively.
(The lines above are actually taken from the bounding-box regression part of the training data for the MTCNN face-detection network; the MTCNN training data format will be covered in the next post. A sketch of one common way such targets are computed follows.)
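For illustration only — the Box struct and BoxDeltas helper below are made up for this post, and the parameterization shown is the Faster R-CNN style (log-scale for width/height), not MTCNN's exact targets:
#include <cmath>
#include <vector>

struct Box { float x, y, w, h; };  // center (x, y), width w, height h

// Shift/scale deltas that map a candidate box onto the ground-truth box
// (Faster R-CNN-style parameterization; a sketch, not the author's code).
std::vector<float> BoxDeltas(const Box& cand, const Box& gt) {
  std::vector<float> d(4);
  d[0] = (gt.x - cand.x) / cand.w;  // x shift, relative to candidate width
  d[1] = (gt.y - cand.y) / cand.h;  // y shift, relative to candidate height
  d[2] = std::log(gt.w / cand.w);   // log-scale change in width
  d[3] = std::log(gt.h / cand.h);   // log-scale change in height
  return d;
}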
--------------------------------------------------------------------------------
caffe supports single-label classification well, but its support for multi-label classification and regression is poor: build/tools/convert_imageset, the executable that generates lmdb datasets, only accepts a single integer label per image. Multi-label classification and regression clearly require the tool to accept multiple floating-point labels, and that is exactly what we are going to add.
Note that reading classification labels in as floating-point numbers does not affect the classification results.
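Here is a minimal standalone sketch (not part of the caffe changes) of why that is safe: every small integer is exactly representable as a 32-bit float, so a class index written as float and read back is unchanged.
#include <cassert>
#include <cstdio>

int main() {
  // Integers up to 2^24 are exactly representable in a 32-bit float,
  // so class indices survive the round trip through float unchanged.
  for (int label = 0; label < 100000; ++label) {
    float stored = static_cast<float>(label);  // what the tool writes
    int recovered = static_cast<int>(stored);  // what a loss layer reads
    assert(recovered == label);
  }
  std::printf("integer labels round-trip through float exactly\n");
  return 0;
}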
I want to keep using lmdb datasets, which means modifying some caffe source code. The changes are somewhat involved, but since they basically only have to be made once, they are worth the effort.
Reference: the blog post 用 caffe 做回归 (上).
That post solves the problem for the most part; I have consolidated its changes here. In total, the following files need to be modified:
- tools/convert_imageset.cpp: generates the lmdb dataset. Copy this file, rename the copy to tools/convert_imageset_multi.cpp, and make the changes there.
- include/caffe/util/io.hpp: declares a new function for the tool above
- src/caffe/util/io.cpp: implements that new function
- src/caffe/layers/data_layer.cpp: loads the lmdb dataset
- src/caffe/proto/caffe.proto: adds the label_num parameter
To be clear up front: I am using the BVLC version of caffe (as of July 2018): https://github.com/BVLC/caffe
If your caffe version differs, these files will most likely not work as-is. The modified versions follow; //### marks the modified parts. The main idea is to turn the int label into a vector<float>:
1. tools/convert_imageset_multi.cpp:
// This program converts a set of images to a lmdb/leveldb by storing them
// as Datum proto buffers.
// Usage:
// convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
//
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// should be a list of files as well as their labels, in the format as
// subfolder1/file1.JPEG 7 1 0 //### a line may now carry multiple labels
// ....
#include <algorithm>
#include <fstream> // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>
#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"
#include <iostream> //###
#include <boost/tokenizer.hpp> //###
using namespace caffe; // NOLINT(build/namespaces)
using std::pair;
using boost::scoped_ptr;
DEFINE_bool(gray, false,
"When this option is on, treat images as grayscale ones");
DEFINE_bool(shuffle, false,
"Randomly shuffle the order of images and their labels");
DEFINE_string(backend, "lmdb",
"The backend {lmdb, leveldb} for storing the result");
DEFINE_int32(resize_width, 0, "Width images are resized to");
DEFINE_int32(resize_height, 0, "Height images are resized to");
DEFINE_bool(check_size, false,
"When this option is on, check that all the datum have the same size");
DEFINE_bool(encoded, false,
"When this option is on, the encoded image will be save in datum");
DEFINE_string(encode_type, "",
"Optional: What type should we encode the image as ('png','jpg',...).");
int main(int argc, char** argv) {
#ifdef USE_OPENCV
::google::InitGoogleLogging(argv[0]);
// Print output to stderr (while still logging)
FLAGS_alsologtostderr = 1;
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
"format used as input for Caffe.\n"
"Usage:\n"
" convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
"The ImageNet dataset for the training demo is at\n"
" http://www.image-net.org/download-images\n");
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (argc < 4) {
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
return 1;
}
const bool is_color = !FLAGS_gray;
const bool check_size = FLAGS_check_size;
const bool encoded = FLAGS_encoded;
const string encode_type = FLAGS_encode_type;
std::ifstream infile(argv[2]);
// std::vector<std::pair<std::string, int> > lines;
std::vector<std::pair<std::string, std::vector<float> > > lines; //### int -> vector<float>
std::string line;
// size_t pos; //### no longer used
// int label; //###
std::vector<float> labels; //###
while (std::getline(infile, line)) {
//###
// pos = line.find_last_of(' ');
// label = atoi(line.substr(pos + 1).c_str());
// lines.push_back(std::make_pair(line.substr(0, pos), label));
//###
std::vector<std::string> tokens;
boost::char_separator<char> sep(" ");
boost::tokenizer<boost::char_separator<char> > tok(line, sep);
tokens.clear();
std::copy(tok.begin(), tok.end(), std::back_inserter(tokens));
for (int i = 1; i < tokens.size(); ++i)
labels.push_back(atof(tokens.at(i).c_str()));
lines.push_back(std::make_pair(tokens.at(0), labels));
labels.clear();
}
if (FLAGS_shuffle) {
// randomly shuffle data
LOG(INFO) << "Shuffling data";
shuffle(lines.begin(), lines.end());
}
LOG(INFO) << "A total of " << lines.size() << " images.";
if (encode_type.size() && !encoded)
LOG(INFO) << "encode_type specified, assuming encoded=true.";
int resize_height = std::max<int>(0, FLAGS_resize_height);
int resize_width = std::max<int>(0, FLAGS_resize_width);
// Create new DB
scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
db->Open(argv[3], db::NEW);
scoped_ptr<db::Transaction> txn(db->NewTransaction());
// Storing to db
std::string root_folder(argv[1]);
Datum datum;
int count = 0;
int data_size = 0;
bool data_size_initialized = false;
for (int line_id = 0; line_id < lines.size(); ++line_id) {
bool status;
std::string enc = encode_type;
if (encoded && !enc.size()) {
// Guess the encoding type from the file name
string fn = lines[line_id].first;
size_t p = fn.rfind('.');
if ( p == fn.npos )
LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
enc = fn.substr(p+1);
std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
}
status = ReadImageToDatum(root_folder + lines[line_id].first, //### unchanged, but now resolves to the new overload
lines[line_id].second, resize_height, resize_width, is_color,
enc, &datum);
if (status == false) continue;
if (check_size) {
if (!data_size_initialized) {
data_size = datum.channels() * datum.height() * datum.width();
data_size_initialized = true;
} else {
const std::string& data = datum.data();
CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
<< data.size();
}
}
// sequential
string key_str = caffe::format_int(line_id, 8) + "_" + lines[line_id].first;
// Put in db
string out;
CHECK(datum.SerializeToString(&out));
txn->Put(key_str, out);
if (++count % 1000 == 0) {
// Commit db
txn->Commit();
txn.reset(db->NewTransaction());
LOG(INFO) << "Processed " << count << " files.";
}
}
// write the last batch
if (count % 1000 != 0) {
txn->Commit();
LOG(INFO) << "Processed " << count << " files.";
}
#else
LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV.";
#endif // USE_OPENCV
return 0;
}
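One caveat: nothing above checks that every line of LISTFILE carries the same number of labels. The data layer in step 4 reads exactly label_num values from each datum, so make sure all lines have the same label count and that it matches the label_num you will set in the data layer's prototxt.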
2. include/caffe/util/io.hpp (excerpt):
Find the ReadImageToDatum function and add, below it (without deleting the original), an overload that takes a vector<float>.
bool ReadImageToDatum(const string& filename, const int label,
const int height, const int width, const bool is_color,
const std::string & encoding, Datum* datum);
bool ReadImageToDatum(const string& filename, const vector<float>& labels, //###
const int height, const int width, const bool is_color,
const std::string & encoding, Datum* datum);
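No change is needed at the call site in convert_imageset_multi.cpp: because lines[line_id].second is now a vector<float>, C++ overload resolution automatically picks this new function (hence the //### note next to the ReadImageToDatum call above).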
3. src/caffe/util/io.cpp (excerpt):
Add the implementation of the overload declared above; the original single-label implementation is shown first for reference. (The commented-out block handles encoded image formats; since encoding defaults to '', it can simply be commented out in the new overload.)
bool ReadImageToDatum(const string& filename, const int label,
const int height, const int width, const bool is_color,
const std::string & encoding, Datum* datum) {
cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
if (cv_img.data) {
if (encoding.size()) {
if ( (cv_img.channels() == 3) == is_color && !height && !width &&
matchExt(filename, encoding) )
return ReadFileToDatum(filename, label, datum);
std::vector<uchar> buf;
cv::imencode("."+encoding, cv_img, buf);
datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
buf.size()));
datum->set_label(label);
datum->set_encoded(true);
return true;
}
CVMatToDatum(cv_img, datum);
datum->set_label(label);
return true;
} else {
return false;
}
}
bool ReadImageToDatum(const string& filename, const vector<float>& labels,
const int height, const int width, const bool is_color,
const std::string & encoding, Datum* datum) {
cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
if (cv_img.data) {
// if (encoding.size()) {
// if ( (cv_img.channels() == 3) == is_color && !height && !width &&
// matchExt(filename, encoding) )
// return ReadFileToDatum(filename, label, datum);
// std::vector<uchar> buf;
// cv::imencode("."+encoding, cv_img, buf);
// datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
// buf.size()));
// datum->set_label(label);
// datum->set_encoded(true);
// return true;
// }
CVMatToDatum(cv_img, datum);
// datum->set_label(label);
//###
for (int i = 0; i < labels.size(); ++i)
datum->add_float_data(labels.at(i));
return true;
} else {
return false;
}
}
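The effect of the new overload: the pixel bytes still go into datum.data() via CVMatToDatum, while the labels end up in Datum's repeated float_data field, which is exactly what data_layer.cpp reads back in step 4. A quick sketch for inspecting a datum (the helper name is made up):
#include <iostream>
#include "caffe/proto/caffe.pb.h"

// Hypothetical debugging helper: print the labels that the new
// ReadImageToDatum overload stored in a Datum.
void PrintFloatLabels(const caffe::Datum& datum) {
  std::cout << datum.float_data_size() << " labels:";
  for (int i = 0; i < datum.float_data_size(); ++i)
    std::cout << " " << datum.float_data(i);
  std::cout << std::endl;
}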
4. src/caffe/layers/data_layer.cpp:
#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif // USE_OPENCV
#include <stdint.h>
#include <vector>
#include "caffe/data_transformer.hpp"
#include "caffe/layers/data_layer.hpp"
#include "caffe/util/benchmark.hpp"
namespace caffe {
template <typename Dtype>
DataLayer<Dtype>::DataLayer(const LayerParameter& param)
: BasePrefetchingDataLayer<Dtype>(param),
offset_() {
db_.reset(db::GetDB(param.data_param().backend()));
db_->Open(param.data_param().source(), db::READ);
cursor_.reset(db_->NewCursor());
}
template <typename Dtype>
DataLayer<Dtype>::~DataLayer() {
this->StopInternalThread();
}
template <typename Dtype>
void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int batch_size = this->layer_param_.data_param().batch_size();
// Read a data point, and use it to initialize the top blob.
Datum datum;
datum.ParseFromString(cursor_->value());
// Use data_transformer to infer the expected blob shape from datum.
vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
this->transformed_data_.Reshape(top_shape);
// Reshape top[0] and prefetch_data according to the batch_size.
top_shape[0] = batch_size;
top[0]->Reshape(top_shape);
for (int i = 0; i < this->prefetch_.size(); ++i) {
this->prefetch_[i]->data_.Reshape(top_shape);
}
LOG_IF(INFO, Caffe::root_solver())
<< "output data size: " << top[0]->num() << ","
<< top[0]->channels() << "," << top[0]->height() << ","
<< top[0]->width();
//###
// label
// if (this->output_labels_) {
// vector<int> label_shape(1, batch_size);
// top[1]->Reshape(label_shape);
// for (int i = 0; i < this->prefetch_.size(); ++i) {
// this->prefetch_[i]->label_.Reshape(label_shape);
// }
// }
//###
const int label_num = this->layer_param_.data_param().label_num();
if (this->output_labels_) {
vector<int> label_shape;
label_shape.push_back(batch_size);
label_shape.push_back(label_num);
label_shape.push_back(1);
label_shape.push_back(1);
top[1]->Reshape(label_shape);
for (int i = 0; i < this->prefetch_.size(); ++i) {
this->prefetch_[i]->label_.Reshape(label_shape);
}
}
}
template <typename Dtype>
bool DataLayer<Dtype>::Skip() {
int size = Caffe::solver_count();
int rank = Caffe::solver_rank();
bool keep = (offset_ % size) == rank ||
// In test mode, only rank 0 runs, so avoid skipping
this->layer_param_.phase() == TEST;
return !keep;
}
template<typename Dtype>
void DataLayer<Dtype>::Next() {
cursor_->Next();
if (!cursor_->valid()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Restarting data prefetching from start.";
cursor_->SeekToFirst();
}
offset_++;
}
// This function is called on prefetch thread
template<typename Dtype>
void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
CPUTimer batch_timer;
batch_timer.Start();
double read_time = 0;
double trans_time = 0;
CPUTimer timer;
CHECK(batch->data_.count());
CHECK(this->transformed_data_.count());
const int batch_size = this->layer_param_.data_param().batch_size();
Datum datum;
for (int item_id = 0; item_id < batch_size; ++item_id) {
timer.Start();
while (Skip()) {
Next();
}
datum.ParseFromString(cursor_->value());
read_time += timer.MicroSeconds();
if (item_id == 0) {
// Reshape according to the first datum of each batch
// on single input batches allows for inputs of varying dimension.
// Use data_transformer to infer the expected blob shape from datum.
vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
this->transformed_data_.Reshape(top_shape);
// Reshape batch according to the batch_size.
top_shape[0] = batch_size;
batch->data_.Reshape(top_shape);
}
// Apply data transformations (mirror, scale, crop...)
timer.Start();
int offset = batch->data_.offset(item_id);
Dtype* top_data = batch->data_.mutable_cpu_data();
this->transformed_data_.set_cpu_data(top_data + offset);
this->data_transformer_->Transform(datum, &(this->transformed_data_));
//###
// Copy label.
// if (this->output_labels_) {
// Dtype* top_label = batch->label_.mutable_cpu_data();
// top_label[item_id] = datum.label();
// }
//###
const int label_num = this->layer_param_.data_param().label_num();
if (this->output_labels_) {
Dtype* top_label = batch->label_.mutable_cpu_data();
for (int i = 0; i < label_num; i++){
top_label[item_id * label_num + i] = datum.float_data(i); // read the float labels
}
}
trans_time += timer.MicroSeconds();
Next();
}
timer.Stop();
batch_timer.Stop();
DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
DLOG(INFO) << " Read time: " << read_time / 1000 << " ms.";
DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}
INSTANTIATE_CLASS(DataLayer);
REGISTER_LAYER_CLASS(Data);
} // namespace caffe
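With these changes, top[1] has shape (batch_size, label_num, 1, 1) and is stored row-major, so the label of task k for image n sits at offset n * label_num + k, which is exactly the indexing used in load_batch above. Inside the network definition, a Slice layer along axis 1 can then split this blob into per-task label blobs.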
5. src/caffe/proto/caffe.proto (excerpt):
Add a label_num field to DataParameter.
message DataParameter {
enum DB {
LEVELDB = 0;
LMDB = 1;
}
// Specify the data source.
optional string source = 1;
// Specify the batch size.
optional uint32 batch_size = 4;
// The rand_skip variable is for the data layer to skip a few data points
// to avoid all asynchronous sgd clients to start at the same point. The skip
// point would be set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of keys in the database.
// DEPRECATED. Each solver accesses a different subset of the database.
optional uint32 rand_skip = 7 [default = 0];
optional DB backend = 8 [default = LEVELDB];
// DEPRECATED. See TransformationParameter. For data pre-processing, we can do
// simple scaling and subtracting the data mean, if provided. Note that the
// mean subtraction is always carried out before scaling.
optional float scale = 2 [default = 1];
optional string mean_file = 3;
// DEPRECATED. See TransformationParameter. Specify if we would like to randomly
// crop an image.
optional uint32 crop_size = 5 [default = 0];
// DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
// data.
optional bool mirror = 6 [default = false];
// Force the encoded image to have 3 color channels
optional bool force_encoded_color = 9 [default = false];
// Prefetch queue (Increase if data feeding bandwidth varies, within the
// limit of device memory for GPU training)
optional uint32 prefetch = 10 [default = 4];
//### For multi-task training
optional uint32 label_num = 11 [default = 1];
}
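The field number 11 is chosen because it is the lowest number not already used in DataParameter (prefetch = 10 is the highest existing field in this tree); if your caffe.proto already uses 11, pick the next free number instead.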
That completes the code changes. Rebuild caffe: the build regenerates caffe.pb.h/caffe.pb.cc from the modified caffe.proto (which is what makes data_param().label_num() available to data_layer.cpp) and produces a new executable, build/tools/convert_imageset_multi, that can read the kind of multi-label file shown at the beginning.
How to invoke it will be covered in the next post.