Caffe2 IO
本文主要记录下我对Caffe2的输入输出部分源代码的理解。数据是以什么样的形式输入进网络的,训练过程中如何保存网络模型。与数据输入相关的Operator是DBReader, ImageInputOp, 与存储训练过程中保存模型相关信息的是SaveOp, LoadOp,以及一系列与序列化相关的工具类,比如BlobSerializer。下面分别介绍一下,如有理解错误,欢迎指出。PS,Caffe2的代码写得真心赞啊。
- DBReader
- ImageInputOp
- SaveOp
- LoadOp
- 总结
DBReader
如同Caffe1一样,一般情况下,在进行模型训练的时候,Caffe2也需要事先将数据转成特定格式的数据库,比如lmdb, leveldb。只不过Caffe2支持的数据库格式更加丰富,除了上述两种格式的db外,还有minidb, zmqdb, protodb, rocksdb等等。Caffe2中对lmdb的实现跟Caffe1有所不同,但功能是一样的。PS,个人以为Caffe1中的实现要优雅些,因为我直接在windows上用Caffe2自带的lmdb.cc来生成数据库时运行不通过,直接改成Caffe1中的就OK了。另外由于Caffe2在默认保存模型时候使用的是minidb, 所以简单地介绍下minidb。
DBReader封装了如何读取数据库的操作。注意在单机多GPU情况下DBReader只有一个实例,为各个GPU共享。在多机的情况下,每台机器有一个DBReader实例,通过DBReader中的成员变量shard_id_来标识该节点负责读取哪一部分的数据库。通常,每一台机器都会有一份完整的相同的数据库,当然也可以通过nfs将数据库从一台机器映射给其他机器。读取同一个数据库的时候。DBReader自动会对数据进行切片,保证每个节点的每个GPU读取数据库的不同部分,以此达到数据并行。DBReader的摘要如下:
class DBReader {
...
private:
string db_type_; //数据库的类型,包括minidb,leveldb,lmdb等等
string source_; //数据库的路径
unique_ptr<DB> db_; //数据库对象
unique_ptr<Cursor> cursor_; //数据库游标
mutable std::mutex reader_mutex_;//单机多GPU环境下,应该是多线程进行训练,多线程共享同一个DBReader实例,因此需要用这个reader_mutex来控制对共享变量的访问。
uint32_t num_shards_; //单机环境下,该值为0,分布式环境下,该值为节点数目。
uint32_t shard_id_; //节点id,从0开始,单机情况下为0,依次递增,
DISABLE_COPY_AND_ASSIGN(DBReader);
public:
void Open(const string& db_type, const string& source, const int32_t num_shards = 1, const int32_t shard_id = 0) { //打开数据库,该函数会在构造函数里被调用
cursor_.reset();
db_.reset();
db_type_ = db_type;
source_ = source;
db_ = CreateDB(db_type_, source_, READ);
CAFFE_ENFORCE(db_, "Cannot open db: ", source_, " of type ", db_type_);
InitializeCursor(num_shards, shard_id);
}
// for i = 0: batch_size, call Read
void Read(string* key, string* value) const {
CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
std::unique_lock<std::mutex> mutex_lock(reader_mutex_);//这里注意,只对单机多GPU会阻塞,不同机器之间不会阻塞,因为是不同的DBReader实例,多机通信会通过rendezvous进行同步,比如redis _store_handler等。
*key = cursor_->key();
*value = cursor_->value();
// 在分布式环境下,由于一次有num_shards台机器参与读取数据,因此一次计算读取的数据量有num_shards * 每台机器读取的数据量,所以对于每一台机器而言,这里要跳过num_shards个记录,才是它下一次迭代应该读取的数据库位置
for (int s = 0; s < num_shards_; s++) {
cursor_->Next();
if (!cursor_->Valid()) {
MoveToBeginning();
break;
}
}
}
...
};
DB, Transaction, Cursor三个接口类定义了如何操作数据库。对于不同类型的数据库,会有相应的实现,比如针对lmdb,就有LMDB, LMDBTransaction, LMDBCursor,针对minidb,就有MiniDB, MiniDBTransaction, MiniDBCursor。从Caffe2中实现的lmdb,minidb, leveldb来看,读数据库只支持顺序读取,即cursor从头到尾顺序访问数据库,当访问到数据库末尾时候,cursor又从头开始,因此并不支持对数据库的随机访问。DB的摘要如下:
class DB {
public:
DB(const string& /*source*/, Mode mode) : mode_(mode) {}
virtual ~DB() { }
/**
* Closes the database.
*/
virtual void Close() = 0;
/**
* Returns a cursor to read the database. The caller takes the ownership of
* the pointer.
*/
virtual std::unique_ptr<Cursor> NewCursor() = 0;
/**
* Returns a transaction to write data to the database. The caller takes the
* ownership of the pointer.
*/
virtual std::unique_ptr<Transaction> NewTransaction() = 0;
protected:
Mode mode_; //这个mode定义为enum Mode { READ, WRITE, NEW };
DISABLE_COPY_AND_ASSIGN(DB);
};
minidb相关操作
minidb其实就是简单地封装了C语言中的文件IO调用, 没啥特别之处,直接把caffe2/core/db.cc中的代码贴出来。因为有这个minidb的存在,因此Caffe2就不像Caffe1中有辣么多依赖软件了。lmdb和leveldb对Caffe2来说就是可选的了。不过,minidb的功能肯定不如lmdb了(个人猜测,minidb的读写效率啊,估计也没有lmdb高)。
class MiniDBCursor : public Cursor {
public:
explicit MiniDBCursor(FILE* f, std::mutex* mutex)
: file_(f), lock_(*mutex), valid_(true) {
// We call Next() to read in the first entry.
Next();
}
~MiniDBCursor() {}
void Seek(const string& /*key*/) override {
LOG(FATAL) << "MiniDB does not support seeking to a specific key.";
}
void SeekToFirst() override {
fseek(file_, 0, SEEK_SET);
CAFFE_ENFORCE(!feof(file_), "Hmm, empty file?");
// Read the first item.
valid_ = true;
Next();
}
void Next() override {
// First, read in the k