搞了这么长时间DeepLearning,打算用回忆的方式,进行一下总结。我们知道,caffe只能识别leveldb或者lmdb格式的文件,所以一切从数据转换开始。若想自己写转换函数程序(matlab/Python),自然而然需要读懂caffe中examples里转换的函数。
下面是mnist_convert.cpp的程序:
/***********************************************************************************************************************************
TIPS:caffe为什么采用lmdb或者leveldb,而不是直接读取原始数据呢?
一方面,数据类型五花八门,种类繁多,不可能用一套代码实现所有类型的输入数据要求,转换为统一格式可以简化数据读取层的实现;另一方面,使用leveldb或者lmdb可以提高磁盘IO利用率。
/************************************************************************************************************************************
引用相应的文件和命名空间:
//
//该程序将mnist数据集转换为caffe需要的格式(lmdb)
//用法:mnist_convert_data input_folder output_db_file
#include <gflags/gflags.h> //gflags命令行参数解析的头文件
#include <glog/logging.h> //记录程序日志的glog头文件
#include <google/protobuf/text_format.h> //解析proto类型文件中,解析prototxt类型的头文件
#if defined(USE_LEVELDB) && defined(USE_LMDB)
#include <leveldb/db.h> //引入leveldb类型数据头文件
#include <leveldb/write_batch.h> //引入leveldb类型数据写入头文件
#include <lmdb.h>
#endif
#if defined(_MSC_VER)
#include <direct.h>
#define mkdir(X, Y) _mkdir(X)
#endif
#include <stdint.h>
#include <sys/stat.h>
#include <fstream> //NOLINT(readability/streams)
#include <string>
#include "boost/scoped_ptr.hpp"
#include "caffe/proto/caffe.pb.h" //解析caffe中proto类型文件的头文件
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"
#if defined(USE_LEVELDB) && defined(USE_LMDB)
using namespace caffe; //NOLINT(build/namespaces)
using boost::scoped_ptr;
using std::string;
定义backend(程序变量):
大端字节存储的二进制文件与小端字节存储的二进制文件转换:
/****************************************************************************************************************************************************************************************
TIPS:为何需要两种二进制文件转换?
大小端字节的计算机存储的二进制文件格式不同,大端计算机无法读取小端计算机存储的二进制文件(小端一样)所以需要两种文件的转换。
/*****************************************************************************************************************************************************************************************
uint32_t swap_endian(uint32_t val)
{
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
return (val << 16) | (val >> 16);
}
convert_dataset函数(核心代码):
void convert_dataset(const char* image_filename, const char* label_filename,const char* db_path, const string& db_backend)
{
//打开(二进制)文件
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
//CHECK用于检测文件是否正常打开
CHECK(image_file) << "Unable to open file " << image_filename;
CHECK(label_file) << "Unable to open file " << label_filename;
//根据mnist图像结构,定义长、宽、样本数、标签数
//uint32_t是自定义数据类型,unsigned int 32是指每个int32整数占用4个字节
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;
//读取图片数据结构
//image的维度为4(magic,num_items,width,height)
//label的维度为2(magic,num_labels)
image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels);
CHECK_EQ(num_items, num_labels);
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);
//定义lmdb和leveldb类的变量
MDB_env *mdb_env;
MDB_dbi mdb_dbi;
MDB_val mdb_key, mdb_data;
MDB_txn *mdb_txn;
leveldb::DB* db;
leveldb::Options options;
options.error_if_exists = true;
options.create_if_missing = true;
options.write_buffer_size = 268435456;
leveldb::WriteBatch* batch = NULL;
//open the files
if (db_backend == "leveldb") {
// leveldb
LOG(INFO) << "Opening leveldb " << db_path;
leveldb::Status status = leveldb::DB::Open(options, db_path, &db);
CHECK(status.ok()) << "Failed to open leveldb " << db_path<< ". Is it already existing?";
batch = new leveldb::WriteBatch();
//Storing to db
char label;
char* pixels = new char[rows * cols];
int count = 0;
string value;
//define the detum
Datum datum;
datum.set_channels(1);
datum.set_height(rows);
datum.set_width(cols);
LOG(INFO) << "A total of " << num_items << " items.";
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
//read the files and assign to "datum"
for (int item_id = 0; item_id < num_items; ++item_id)
{
image_file.read(pixels, rows * cols);
label_file.read(&label, 1);
datum.set_data(pixels, rows*cols);
datum.set_label(label);
string key_str = caffe::format_int(item_id, 8);
datum.SerializeToString(&value);
txn->Put(key_str, value);
//write to the batch
if (++count % 1000 == 0)
{
txn->Commit();
}
}
//write the last batch
if (count % 1000 != 0)
{
txn->Commit();
}
LOG(INFO) << "Processed " << count << " files.";
delete[] pixels;
db->Close();
}
main函数(主函数代码):
int main(int argc, char** argv)
{
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
FLAGS_alsologtostderr = 1; //获取--backend=${BACKEND}参数
gflags::SetUsageMessage("This script converts the MNIST dataset to\n"
"the lmdb/leveldb format used by Caffe to load data.\n"
"Usage:\n"
"convert_mnist_data [FLAGS] input_image_file input_label_file "
"output_db_file\n"
"The MNIST dataset could be downloaded at\n"
"http://yann.lecun.com/exdb/mnist/\n"
"You should gunzip them after downloading,"
"or directly use data/mnist/get_mnist.sh\n");
gflags::ParseCommandLineFlags(&argc, &argv, true);
const string& db_backend = FLAGS_backend;
if (argc != 4)
{
gflags::ShowUsageWithFlagsRestrict(argv[0],"examples/mnist/convert_mnist_data");
}
else
{
google::InitGoogleLogging(argv[0]);
convert_dataset(argv[1], argv[2], argv[3], db_backend);
}
return 0;
}
#else
int main(int argc, char** argv)
{
LOG(FATAL) << "This example requires LevelDB and LMDB; " << "compile with USE_LEVELDB and USE_LMDB.";
}