梳理caffe代码data_reader(十一)

本文详细介绍了Caffe框架中的数据读取机制,包括利用阻塞队列进行数据的高效安全传输过程。通过单线程从数据源读取数据,并将其放入阻塞队列中供数据层使用,确保了多GPU训练时数据的有序性和确定性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上一篇的blocking_queue到底干了一件什么事情呢?刚刚看完就有点忘记了,再过一会估计忘光了。。。

顾名思义,阻塞队列,就是一个正在排队的打饭队列,先到窗口的先打饭,为什么会高效安全呢?一是像交通有秩序,二是有了秩序是不是交通运行起来就快了。

我们就看看数据是怎么进行排队的?

头文件:

#ifndef CAFFE_DATA_READER_HPP_
#define CAFFE_DATA_READER_HPP_

#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/util/blocking_queue.hpp"
#include "caffe/util/db.hpp"

namespace caffe {

/**
 * @brief Reads data from a source to queues available to data layers.
 * A single reading thread is created per source, even if multiple solvers
 * are running in parallel, e.g. for multi-GPU training. This makes sure
 * databases are read sequentially, and that each solver accesses a different
 * subset of the database. Data is distributed to solvers in a round-robin
 * way to keep parallel training deterministic.
 */
/*
从共享的资源读取数据然后排队输入到数据层,每个资源创建单个线程,即便是使用多个GPU在并行任务中求解。这就保证对于频繁读取数据库,并且每个求解的线程使用的子数据是不同的。数据成功设计就是这样使在求解时数据保持一种循环地并行训练。
*/
class DataReader {
 public:
  explicit DataReader(const LayerParameter& param);
  ~DataReader();
//
  inline BlockingQueue<Datum*>& free() const {
    return queue_pair_->free_;
  }
  inline BlockingQueue<Datum*>& full() const {
    return queue_pair_->full_;
  }

 protected:
  // Queue pairs are shared between a body and its readers
  class QueuePair {
   public:
    explicit QueuePair(int size);
    ~QueuePair();
//定义了两个阻塞队列free_和full_
    BlockingQueue<Datum*> free_;
    BlockingQueue<Datum*> full_;

  DISABLE_COPY_AND_ASSIGN(QueuePair);
  };

  // A single body is created per source
//继承InternalThread 这个类的
  class Body : public InternalThread {
   public:
    explicit Body(const LayerParameter& param);
    virtual ~Body();

   protected:
//重写了InternalThread内部的InternalThreadEntry函数,此外还添加了read_one函数
    void InternalThreadEntry();
    void read_one(db::Cursor* cursor, QueuePair* qp);

    const LayerParameter param_;
    BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;
//内部有DataReader的友元
    friend class DataReader;

  DISABLE_COPY_AND_ASSIGN(Body);
  };

  // A source is uniquely identified by its layer name + path, in case
  // the same database is read from two different locations in the net.
  static inline string source_key(const LayerParameter& param) {
    return param.name() + ":" + param.data_param().source();
  }

  const shared_ptr<QueuePair> queue_pair_;
  shared_ptr<Body> body_;

  static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;

DISABLE_COPY_AND_ASSIGN(DataReader);
};

}  // namespace caffe

#endif  // CAFFE_DATA_READER_HPP_
实现部分:
#include <boost/thread.hpp>
#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/data_reader.hpp"
#include "caffe/layers/data_layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

using boost::weak_ptr;

map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
static boost::mutex bodies_mutex_;

DataReader::DataReader(const LayerParameter& param)
    : queue_pair_(new QueuePair(  //
        param.data_param().prefetch() * param.data_param().batch_size())) {
  // Get or create a body
  boost::mutex::scoped_lock lock(bodies_mutex_);
  string key = source_key(param);
  weak_ptr<Body>& weak = bodies_[key];
  body_ = weak.lock();
  if (!body_) {
    body_.reset(new Body(param));
    bodies_[key] = weak_ptr<Body>(body_);
  }
  body_->new_queue_pairs_.push(queue_pair_);
}

DataReader::~DataReader() {
  string key = source_key(body_->param_);
  body_.reset();
  boost::mutex::scoped_lock lock(bodies_mutex_);
  if (bodies_[key].expired()) {
    bodies_.erase(key);
  }
}

//根据给定的size初始化的若干个Datum的实例到free里面

DataReader::QueuePair::QueuePair(int size) {
  // Initialize the free queue with requested number of datums
  for (int i = 0; i < size; ++i) {
    free_.push(new Datum());
  }
}
//将full_和free_这两个队列里面的Datum对象全部delete。
DataReader::QueuePair::~QueuePair() {
  Datum* datum;
  while (free_.try_pop(&datum)) {
    delete datum;
  }
  while (full_.try_pop(&datum)) {
    delete datum;
  }
}
//Body类的构造函数,实际上是给定网络的参数,然后开始启动内部线程
DataReader::Body::Body(const LayerParameter& param)
    : param_(param),
      new_queue_pairs_() {
  StartInternalThread();// 调用InternalThread内部的函数来初始化运行环境以及新建线程去执行虚函数InternalThreadEntry的内容
}
// 析构,停止线程
DataReader::Body::~Body() {
  StopInternalThread();
}

// 自己实现的需要执行的函数
// 首先打开数据库,然后设置游标,然后设置QueuePair指针容器
void DataReader::Body::InternalThreadEntry() {
  // 获取所给定的数据源的类型来得到DB的指针
  shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
  // 从网络参数中给定的DB的位置打开DB
  db->Open(param_.data_param().source(), db::READ);
  // 新建游标指针
  shared_ptr<db::Cursor> cursor(db->NewCursor());
  // 新建QueuePair指针容器,QueuePair里面包含了free_和full_这两个阻塞队列
  vector<shared_ptr<QueuePair> > qps;
  try {
    // 根据网络参数的阶段来设置solver_count
    int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;

    // To ensure deterministic runs, only start running once all solvers
    // are ready. But solvers need to peek on one item during initialization,
    // so read one item, then wait for the next solver.
    for (int i = 0; i < solver_count; ++i) {
      shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
      read_one(cursor.get(), qp.get());// 读取一个数据
      qps.push_back(qp);压入
    }
    // Main loop
    while (!must_stop()) {
      for (int i = 0; i < solver_count; ++i) {
        read_one(cursor.get(), qps[i].get());
      }
      // Check no additional readers have been created. This can happen if
      // more than one net is trained at a time per process, whether single
      // or multi solver. It might also happen if two data layers have same
      // name and same source.
      CHECK_EQ(new_queue_pairs_.size(), 0);
    }
  } catch (boost::thread_interrupted&) {
    // Interrupted exception is expected on shutdown
  }
}

// 从数据库中获取一个数据
void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {
  // 从QueuePair中的free_队列pop出一个
  Datum* datum = qp->free_.pop();
  // TODO deserialize in-place instead of copy?
  // 然后解析cursor中的值
  datum->ParseFromString(cursor->value());
  // 然后压入QueuePair中的full_队列
  qp->full_.push(datum);

  // go to the next iter
  // 游标指向下一个
  cursor->Next();
  if (!cursor->valid()) {
    DLOG(INFO) << "Restarting data prefetching from start.";
    cursor->SeekToFirst();// 如果游标指向的位置已经无效了则指向第一个位置
  }
}

}  // namespace caffe
数据层就是调用了封装层的DB来读取数据,此外还简单封装了boost的线程库,然后自己封装了个阻塞队列。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值