Caffe 源码阅读笔记 [数据读入和处理] DataReader和DataTransformer

概述

这一篇主要阅读Caffe如何从数据源读取训练数据和如何对训练数据进行变换。

DataReader

DataReader通过LayerParameter获取配置信息。一个训练数据集可以读取到多个GPU(solver)进行训练。为了保证数据分布均匀,每个训练数据集只分配一个线程进行读取,这个线程把数据通过round-robin的方式分发给不同的GPU(solver),每个solver获得的数据是唯一的,每个solver有独立的数据缓存存储属于他们的数据,数据读取线程的工作是把数据从数据集里读取出来放到solver对应的缓存里面。

Datum是什么?

我在DataReader里看到很多Datum有关的代码,但搜索了所有的cpp和hpp都没有发现它的定义,后来发现原来是在caffe.proto里面定义了,看起来是一个表示图片数据的类型

message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
}

DataReader成员变量

// 包含两个BlockingQueue<Datum*>:free_和full_。free_是空闲指针队列,当我们获取一个数据Datum*后,我们把它从free_移到full_队列里。
shared_ptr<QueuePair> queue_pair_;
// 包含一个后台线程读取数据,和一个QueuePair的队列。
shared_ptr<Body> body_;
// 每个数据库对应一个Body,内包含一个读取数据库的线程
static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;

DataReader初始化

DataReader::DataReader(const LayerParameter& param)
    : queue_pair_(new QueuePair(  
        // 创建预读取队列,往free_队列插入prefetch*batch_size个new Datum()指针
        param.data_param().prefetch() * param.data_param().batch_size())) {
  boost::mutex::scoped_lock lock(bodies_mutex_);
  string key = source_key(param); // param.name() + ":" + param.data_param().source()
  weak_ptr<Body>& weak = bodies_[key];
  body_ = weak.lock();
  if (!body_) {
    // 创建新的body并创建key->body的对应关系
    body_.reset(new Body(param));
    bodies_[key] = weak_ptr<Body>(body_);
  }
  // 把这个queue_pair_加入body_的队列里
  body_->new_queue_pairs_.push(queue_pair_);
}

从数据源读取数据

new Body(param)会起来一个InternalThread进行数据的读取。InternalThread会进而调用InternalThreadEntry进行操作

void DataReader::Body::InternalThreadEntry() {
  // 连接Database
  shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
  db->Open(param_.data_param().source(), db::READ);
  shared_ptr<db::Cursor> cursor(db->NewCursor());
  // 获得当前solver的个数
  int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;
  // 每个solver对应一个QueuePair队列,数据按round-robin的形式分配给每个solver
  while (!must_stop()) {
    for (int i = 0; i < solver_count; ++i) {
      read_one(cursor.get(), new_queue_pairs_[i].get());
    }
  }
}
// 从cursor中读取一个Datum
void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {
  // 从free_队列pop一个空Datum
  Datum* datum = qp->free_.pop();
  // 把当前cursor的值拷贝到datum里
  datum->ParseFromString(cursor->value());
  // 把datum移到full_队列
  qp->full_.push(datum);
  // 移动指针
  cursor->Next();
  // 如果cursor指向不合法的位置,移到初始点重新读取。
  if (!cursor->valid()) {
    DLOG(INFO) << "Restarting data prefetching from start.";
    cursor->SeekToFirst();
  }
}

DataTransformer

数据读取出来后需要做数据的变换操作,比如scaling, mirroring, substracting the image mean,…DataTransformer通过TransformationParameter来配置

// 返回[0, n-1]之间的一个随机数
int Rand(int n)
// 对数据Datum进行变换,transformed_blob是变换结果,Transform有多个不同的实现,但逻辑都是大同小异的。
void Transform(const Datum& datum, Blob<Dtype>* transformed_blob) {
  如果强制要有彩色或者强制要有灰色:
     // 使用openCV
     cv_img = DecodeDatumToCVMat(datum, param_.force_color());
  如果剪切图片crop_size > 0,我们要剪出来一个crop_size*crop_size大小的图片:
     // blob需要做reshape(input_num, input_channels, crop_size, crop_size)
     如果是训练阶段,则随机选一个剪切位置的左上角:
         h_off = Rand(datum_height - crop_size + 1);
         w_off = Rand(datum_width - crop_size + 1);
     否则从中间剪切:
         h_off = (datum_height - crop_size) / 2;
         w_off = (datum_width - crop_size) / 2;
  如果做镜像do_mirror:
     data_index = (c * height + h) * width + w;
     top_index = (c * height + h) * width + (width - 1 - w);
     transformed_data[top_index] = datum[data_index];
  如果scale > 1:
     transformed_data[data_index] = datum[data_index] * scale;
  如果有要diff的文件has_mean_file,mean file数据存在mean数组里:
     transformed_data[data_index] = datum[data_index] - mean[data_index];
  如果有平均值要减去has_mean_values:
     // 注意每个channel都可以有一个不同的均值
     transformed_data[data_index] = datum[data_index] - mean_values_[channel];
}
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值