上一篇介绍的 blocking_queue 到底做了什么事情呢?刚看完就有点忘了,再过一会儿估计就忘光了……
顾名思义,阻塞队列就像食堂窗口前排队打饭的队伍:先到窗口的人先打饭(先进先出);窗口暂时没饭时,打饭的人只能原地等待(这就是"阻塞"),而不是乱挤。为什么这样既高效又安全呢?一是像交通一样有了秩序,多个线程不会互相争抢同一份数据;二是有了秩序,整体运行起来反而更快。
下面就来看看数据是怎么排队的。
头文件:
#ifndef CAFFE_DATA_READER_HPP_
#define CAFFE_DATA_READER_HPP_
#include <map>
#include <string>
#include <vector>
#include "caffe/common.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/util/blocking_queue.hpp"
#include "caffe/util/db.hpp"
namespace caffe {
/**
* @brief Reads data from a source to queues available to data layers.
* A single reading thread is created per source, even if multiple solvers
* are running in parallel, e.g. for multi-GPU training. This makes sure
* databases are read sequentially, and that each solver accesses a different
* subset of the database. Data is distributed to solvers in a round-robin
* way to keep parallel training deterministic.
*/
/*
 * Design summary: every source (see source_key) gets exactly one Body, i.e.
 * one reading thread, shared by all DataReaders for that source. Each
 * DataReader owns a QueuePair (free_/full_) that it registers with the Body;
 * the Body's thread reads the database sequentially and fills the registered
 * queue pairs in turn, one record per reader per round.
 */
class DataReader {
public:
explicit DataReader(const LayerParameter& param);
~DataReader();
// Queue of empty Datum objects. QueuePair's constructor pre-allocates them,
// and the reader thread pops one from here before filling it.
inline BlockingQueue<Datum*>& free() const {
return queue_pair_->free_;
}
// Queue of Datum objects that the reader thread has already filled
// (Body::read_one pushes each parsed Datum here).
inline BlockingQueue<Datum*>& full() const {
return queue_pair_->full_;
}
protected:
// Queue pairs are shared between a body and its readers
class QueuePair {
public:
explicit QueuePair(int size);
~QueuePair();
// The two blocking queues: free_ holds empty Datums, full_ holds filled ones.
// The destructor deletes any Datum still sitting in either queue.
BlockingQueue<Datum*> free_;
BlockingQueue<Datum*> full_;
DISABLE_COPY_AND_ASSIGN(QueuePair);
};
// A single body is created per source
// Body derives from InternalThread; its thread runs InternalThreadEntry().
class Body : public InternalThread {
public:
explicit Body(const LayerParameter& param);
virtual ~Body();
protected:
// Overrides InternalThread::InternalThreadEntry (the thread's main loop);
// read_one is the helper that moves one record into a queue pair.
void InternalThreadEntry();
void read_one(db::Cursor* cursor, QueuePair* qp);
const LayerParameter param_;
// Queue pairs registered by DataReader constructors, consumed by the thread.
BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;
// DataReader needs access to param_ and new_queue_pairs_.
friend class DataReader;
DISABLE_COPY_AND_ASSIGN(Body);
};
// A source is uniquely identified by its layer name + path, in case
// the same database is read from two different locations in the net.
static inline string source_key(const LayerParameter& param) {
return param.name() + ":" + param.data_param().source();
}
// This reader's own queues; also handed to body_ via new_queue_pairs_.
const shared_ptr<QueuePair> queue_pair_;
// The Body shared by all readers of the same source.
shared_ptr<Body> body_;
// Process-wide registry keyed by source_key(). Entries are weak, so a Body
// is destroyed when the last DataReader using it releases body_.
static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;
DISABLE_COPY_AND_ASSIGN(DataReader);
};
} // namespace caffe
#endif // CAFFE_DATA_READER_HPP_
实现部分:
#include <boost/thread.hpp>
#include <map>
#include <string>
#include <vector>
#include "caffe/common.hpp"
#include "caffe/data_reader.hpp"
#include "caffe/layers/data_layer.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
using boost::weak_ptr;

// Registry of live reader bodies, one per source_key(). The entries are weak
// so that a Body dies as soon as its last DataReader lets go of it.
map<const string, boost::weak_ptr<DataReader::Body> > DataReader::bodies_;

// Serializes every lookup, insertion and erasure on bodies_.
static boost::mutex bodies_mutex_;
// Builds a reader whose private QueuePair holds prefetch() * batch_size()
// pre-allocated Datum objects, then attaches that pair to the Body shared by
// every reader of the same source, creating the Body on first use.
DataReader::DataReader(const LayerParameter& param)
    : queue_pair_(new QueuePair(
          param.data_param().prefetch() * param.data_param().batch_size())) {
  // The registry is shared process-wide, so hold its mutex throughout.
  boost::mutex::scoped_lock registry_lock(bodies_mutex_);
  const string source = source_key(param);
  weak_ptr<Body>& slot = bodies_[source];
  body_ = slot.lock();
  if (!body_) {
    // No live Body for this source yet: create one and record it weakly.
    body_.reset(new Body(param));
    slot = body_;
  }
  // Hand this reader's queues to the Body's reading thread.
  body_->new_queue_pairs_.push(queue_pair_);
}
// Detaches this reader from its Body and, if no other reader still holds
// that Body, removes the now-dead entry from the registry.
DataReader::~DataReader() {
  const string source = source_key(body_->param_);
  // Drop our reference before taking the lock; if it was the last one,
  // ~Body() runs here and stops the reading thread.
  body_.reset();
  boost::mutex::scoped_lock registry_lock(bodies_mutex_);
  map<const string, weak_ptr<Body> >::iterator entry = bodies_.find(source);
  if (entry != bodies_.end() && entry->second.expired()) {
    bodies_.erase(entry);
  }
}
// Pre-allocates `size` empty Datum objects and places them in free_, so the
// reading thread has buffers to fill. A non-positive size allocates nothing.
DataReader::QueuePair::QueuePair(int size) {
  for (int remaining = size; remaining > 0; --remaining) {
    free_.push(new Datum());
  }
}
// Deletes every Datum still held by this pair: first the ones in free_, then
// the ones in full_. Each queue is drained until try_pop reports it is empty.
DataReader::QueuePair::~QueuePair() {
  BlockingQueue<Datum*>* const queues[] = { &free_, &full_ };
  for (int q = 0; q < 2; ++q) {
    Datum* leftover = NULL;
    while (queues[q]->try_pop(&leftover)) {
      delete leftover;
    }
  }
}
// Keeps a copy of the layer parameters and immediately launches the reading
// thread through InternalThread; that thread executes InternalThreadEntry().
DataReader::Body::Body(const LayerParameter& param)
    : param_(param), new_queue_pairs_() {
  StartInternalThread();
}
// Shuts the reading thread down (via InternalThread::StopInternalThread)
// before the Body and its queues are destroyed.
DataReader::Body::~Body() {
  StopInternalThread();
}
// Main loop of the reading thread. Opens the database named in the layer
// parameters, waits until one QueuePair per expected solver has been
// registered, then fills those pairs round-robin, one record per solver per
// round, until the thread is asked to stop.
void DataReader::Body::InternalThreadEntry() {
  // Pick the DB implementation for the configured backend.
  shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
  // Open the database at the configured source path, read-only.
  db->Open(param_.data_param().source(), db::READ);
  shared_ptr<db::Cursor> cursor(db->NewCursor());
  // One QueuePair (free_/full_) per solver, in registration order.
  vector<shared_ptr<QueuePair> > qps;
  try {
    // In the TRAIN phase one queue pair is expected per solver; otherwise one.
    int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;

    // To ensure deterministic runs, only start running once all solvers
    // are ready. But solvers need to peek on one item during initialization,
    // so read one item, then wait for the next solver.
    for (int i = 0; i < solver_count; ++i) {
      // Take the next queue pair registered by a DataReader constructor.
      shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
      read_one(cursor.get(), qp.get());
      qps.push_back(qp);
    }
    // Main loop
    while (!must_stop()) {
      for (int i = 0; i < solver_count; ++i) {
        read_one(cursor.get(), qps[i].get());
      }
      // Check no additional readers have been created. This can happen if
      // more than one net is trained at a time per process, whether single
      // or multi solver. It might also happen if two data layers have same
      // name and same source.
      CHECK_EQ(new_queue_pairs_.size(), 0);
    }
  } catch (boost::thread_interrupted&) {
    // Interrupted exception is expected on shutdown
  }
}
// Moves one database record into `qp`: takes an empty Datum from qp->free_,
// deserializes the cursor's current value into it, pushes it onto qp->full_,
// then advances the cursor, wrapping back to the first record at the end.
void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {
  Datum* slot = qp->free_.pop();
  // TODO deserialize in-place instead of copy?
  slot->ParseFromString(cursor->value());
  qp->full_.push(slot);

  cursor->Next();
  if (cursor->valid()) {
    return;
  }
  // Reached the end of the database: start over from the first record.
  DLOG(INFO) << "Restarting data prefetching from start.";
  cursor->SeekToFirst();
}
} // namespace caffe
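总结一下:DataReader 通过封装好的 db 接口(db::DB / db::Cursor)顺序读取数据库,借助对 boost 线程的简单封装 InternalThread 在后台线程中运行读取循环,并用一对阻塞队列 free_ 和 full_ 传递 Datum:读取线程从 free_ 中取出空的 Datum、填充后放入 full_,数据层则通过 full() 取得已经填充好的数据。同一个数据源只会创建一个读取线程(Body),多个求解器按轮流(round-robin)的方式分到数据。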