DeepLearning(基于caffe)实战项目(1)--mnist_convert函数分析

        搞了这么长时间DeepLearning,打算用回忆的方式,进行一下总结。我们知道,caffe只能识别leveldb或者lmdb格式的文件,所以一切从数据转换开始。若想自己写转换函数程序(matlab/Python),自然而然需要读懂caffe中examples里转换的函数。

下面是mnist_convert.cpp的程序:

/***********************************************************************************************************************************

TIPS:caffe为什么采用lmdb或者leveldb,而不是直接读取原始数据呢?

一方面,数据类型五花八门,种类繁多,不可能用一套代码实现所有类型的输入数据要求,转换为统一格式可以简化数据读取层的实现;另一方面,使用leveldb或者lmdb可以提高磁盘IO利用率。

/************************************************************************************************************************************

引用相应的文件和命名空间:

//
//该程序将mnist数据集转换为caffe需要的格式(lmdb)
//用法:mnist_convert_data input_folder output_db_file
#include <gflags/gflags.h>     //gflags命令行参数解析的头文件
#include <glog/logging.h>      //记录程序日志的glog头文件
#include <google/protobuf/text_format.h>      //解析proto类型文件中,解析prototxt类型的头文件
#if defined(USE_LEVELDB) && defined(USE_LMDB)      
#include <leveldb/db.h>         //引入leveldb类型数据头文件
#include <leveldb/write_batch.h>     //引入leveldb类型数据写入头文件
#include <lmdb.h>
#endif
#if defined(_MSC_VER)
#include <direct.h>
#define mkdir(X, Y) _mkdir(X)
#endif
#include <stdint.h>
#include <sys/stat.h>
#include <fstream>  //NOLINT(readability/streams)
#include <string>
#include "boost/scoped_ptr.hpp"
#include "caffe/proto/caffe.pb.h"      //解析caffe中proto类型文件的头文件
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"
#if defined(USE_LEVELDB) && defined(USE_LMDB)

using namespace caffe;  //NOLINT(build/namespaces)
using boost::scoped_ptr;
using std::string;

定义backend(程序变)

//在程序调用时,铜鼓--backend=${BACKEND}来给变量命名
DEFINE_string(backend, "lmdb", "The backend for storing the result");    //GFLAGS工具定义明星行选项backend,默认是lmdb

大端字节存储的二进制文件与小端字节存储的二进制文件转换

/****************************************************************************************************************************************************************************************

TIPS:为何需要两种二进制文件转换?

大小端字节的计算机存储的二进制文件格式不同,大端计算机无法读取小端计算机存储的二进制文件(小端一样)所以需要两种文件的转换。

/*****************************************************************************************************************************************************************************************

uint32_t swap_endian(uint32_t val)
 {
    val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
    return (val << 16) | (val >> 16);
}

convert_dataset函数(核心代码)

void convert_dataset(const char* image_filename, const char* label_filename,const char* db_path, const string& db_backend) 
{
//打开(二进制)文件
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);

//CHECK用于检测文件是否正常打开
CHECK(image_file) << "Unable to open file " << image_filename;
CHECK(label_file) << "Unable to open file " << label_filename;

//根据mnist图像结构,定义长、宽、样本数、标签数
//uint32_t是自定义数据类型,unsigned int 32是指每个int32整数占用4个字节
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;

//读取图片数据结构
//image的维度为4(magic,num_items,width,height)
//label的维度为2(magic,num_labels)
image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels);
CHECK_EQ(num_items, num_labels);
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);

//定义lmdb和leveldb类的变量
MDB_env *mdb_env;  
MDB_dbi mdb_dbi;  
MDB_val mdb_key, mdb_data;  
MDB_txn *mdb_txn;  
leveldb::DB* db;
leveldb::Options options;   
options.error_if_exists = true;
options.create_if_missing = true;
options.write_buffer_size = 268435456; 
leveldb::WriteBatch* batch = NULL;
//open the files
if (db_backend == "leveldb") {  
    // leveldb
    LOG(INFO) << "Opening leveldb " << db_path;
    leveldb::Status status = leveldb::DB::Open(options, db_path, &db);
    CHECK(status.ok()) << "Failed to open leveldb " << db_path<< ". Is it already existing?";
    batch = new leveldb::WriteBatch();
//Storing to db
     char label;
    char* pixels = new char[rows * cols];
    int count = 0;
    string value;
//define the detum
    Datum datum;
    datum.set_channels(1);
    datum.set_height(rows);
    datum.set_width(cols);
    LOG(INFO) << "A total of " << num_items << " items.";
    LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
//read the files and assign to "datum"
    for (int item_id = 0; item_id < num_items; ++item_id) 
  {
    image_file.read(pixels, rows * cols);
    label_file.read(&label, 1);
    datum.set_data(pixels, rows*cols);
    datum.set_label(label);
    string key_str = caffe::format_int(item_id, 8);
    datum.SerializeToString(&value);
    txn->Put(key_str, value);
//write to the batch
    if (++count % 1000 == 0) 
{
      txn->Commit();
}
  }
//write the last batch
    if (count % 1000 != 0)
 {
      txn->Commit();
  }
    LOG(INFO) << "Processed " << count << " files.";
    delete[] pixels;
    db->Close();
}

main函数(主函数代码)

int main(int argc, char** argv) 
{
#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  FLAGS_alsologtostderr = 1;      //获取--backend=${BACKEND}参数

  gflags::SetUsageMessage("This script converts the MNIST dataset to\n"
        "the lmdb/leveldb format used by Caffe to load data.\n"
        "Usage:\n"
        "convert_mnist_data [FLAGS] input_image_file input_label_file "
        "output_db_file\n"
        "The MNIST dataset could be downloaded at\n"
        "http://yann.lecun.com/exdb/mnist/\n"
        "You should gunzip them after downloading,"
        "or directly use data/mnist/get_mnist.sh\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  const string& db_backend = FLAGS_backend;
  if (argc != 4) 
{
    gflags::ShowUsageWithFlagsRestrict(argv[0],"examples/mnist/convert_mnist_data");
  }
 else
 {
    google::InitGoogleLogging(argv[0]);
    convert_dataset(argv[1], argv[2], argv[3], db_backend);
  }
  return 0;
}
#else
int main(int argc, char** argv)
 {
  LOG(FATAL) << "This example requires LevelDB and LMDB; " << "compile with USE_LEVELDB and USE_LMDB.";
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值