An Analysis of Caffe's LMDB and H5 Data-Reading Code

1. Introduction

An earlier article covered how Caffe converts image data into LMDB and H5 files. How, then, does Caffe read those files back into the network for training? Caffe provides dedicated data-reading layers, one per data type. The relationships among Caffe's main data-reading classes are shown below:
(Figure: class-relationship diagram of Caffe's main data-reading classes)
In practice, DataLayer and ImageDataLayer see the most use (though ImageDataLayer reads images inefficiently and is rarely chosen). The analysis below therefore centers on DataLayer. H5 files are read by a separate class, covered later in this article.

2. Reading LMDB

Reading LMDB data is handled by the DataLayer class. Its backend parameter selects the database type, which is initialized right in the constructor:

template <typename Dtype>
DataLayer<Dtype>::DataLayer(const LayerParameter& param)
  : BasePrefetchingDataLayer<Dtype>(param),
    offset_() {
  db_.reset(db::GetDB(param.data_param().backend())); // determine the DB type from the layer parameters and initialize it
  db_->Open(param.data_param().source(), db::READ); // open the DB
  cursor_.reset(db_->NewCursor());
}
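The cursor returned by NewCursor abstracts sequential iteration over the database's key-value records. Below is a minimal self-contained sketch of that interface, backed by a std::map instead of a real LMDB environment (MockCursor and its record contents are illustrative, not Caffe's actual db::Cursor):

```cpp
#include <cassert>
#include <map>
#include <string>

// Illustrative stand-in for caffe::db::Cursor: walks key/value records in
// order, with the wrap-around Caffe performs when the DB is exhausted.
class MockCursor {
 public:
  explicit MockCursor(const std::map<std::string, std::string>& db)
      : db_(db), it_(db_.begin()) {}
  bool valid() const { return it_ != db_.end(); }
  void Next() { ++it_; }                      // advance to the next record
  void SeekToFirst() { it_ = db_.begin(); }   // restart at the first record
  const std::string& key() const { return it_->first; }
  const std::string& value() const { return it_->second; }  // serialized Datum in Caffe
 private:
  std::map<std::string, std::string> db_;
  std::map<std::string, std::string>::const_iterator it_;
};
```

In the real layer, value() returns a serialized Datum string that is parsed with ParseFromString, and SeekToFirst provides the wrap-around when the cursor reaches the end of the database.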

Next, the DataLayerSetUp function reads a single datum to initialize the dimensions of the prefetch buffers and of this layer's output blobs; the actual training data is read later in the thread function.

template <typename Dtype>
void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int batch_size = this->layer_param_.data_param().batch_size(); // batch size configured for this layer
  // Read a data point, and use it to initialize the top blob.
  Datum datum;
  datum.ParseFromString(cursor_->value());

  // Use data_transformer to infer the expected blob shape from datum.
  // set the top blob's C, H, W from the datum's dimensions; if a crop parameter is set, the crop size takes precedence
  vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
  this->transformed_data_.Reshape(top_shape);
  // Reshape top[0] and prefetch_data according to the batch_size.
  top_shape[0] = batch_size; // set the top blob's batch size
  top[0]->Reshape(top_shape);
  // reshape the prefetch buffers according to the prefetch settings
  for (int i = 0; i < this->prefetch_.size(); ++i) {
    this->prefetch_[i]->data_.Reshape(top_shape);
  }
  LOG_IF(INFO, Caffe::root_solver())
      << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // set the label blob's dimensions
  if (this->output_labels_) {
    vector<int> label_shape(1, batch_size);
    top[1]->Reshape(label_shape);
    for (int i = 0; i < this->prefetch_.size(); ++i) {
      this->prefetch_[i]->label_.Reshape(label_shape);
    }
  }
}
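The shape handling above boils down to two small computations: take the {1, C, H, W} shape inferred from the first datum, overwrite the leading dimension with the batch size, and allocate that many elements. A sketch of that arithmetic (BatchTopShape and BlobCount are illustrative helpers, not Caffe functions):

```cpp
#include <cassert>
#include <vector>

// Mimics DataLayerSetUp's shape logic: take the {1, C, H, W} shape inferred
// from the first datum and replace the leading dimension with the batch size.
std::vector<int> BatchTopShape(std::vector<int> datum_shape, int batch_size) {
  datum_shape[0] = batch_size;  // the top blob holds batch_size items
  return datum_shape;
}

// Number of elements the top blob (and each prefetch buffer) must allocate.
int BlobCount(const std::vector<int>& shape) {
  int count = 1;
  for (int d : shape) count *= d;
  return count;
}
```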

The thread function below reads one batch of data:

// Called by the prefetch thread to load one batch of data
template<typename Dtype>
void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());
  const int batch_size = this->layer_param_.data_param().batch_size();

  Datum datum;
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    timer.Start();
    while (Skip()) {
      Next();
    }
    datum.ParseFromString(cursor_->value());
    read_time += timer.MicroSeconds();

    if (item_id == 0) { // subsequent items follow the first datum's dimensions
      // Reshape according to the first datum of each batch
      // on single input batches allows for inputs of varying dimension.
      // Use data_transformer to infer the expected blob shape from datum.
      vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
      this->transformed_data_.Reshape(top_shape);
      // Reshape batch according to the batch_size.
      top_shape[0] = batch_size;
      batch->data_.Reshape(top_shape);
    }

    // Apply data transformations (mirror, scale, crop...)
    timer.Start();
    int offset = batch->data_.offset(item_id);
    Dtype* top_data = batch->data_.mutable_cpu_data();
    this->transformed_data_.set_cpu_data(top_data + offset);
    this->data_transformer_->Transform(datum, &(this->transformed_data_));
    // Copy label.
    if (this->output_labels_) {
      Dtype* top_label = batch->label_.mutable_cpu_data();
      top_label[item_id] = datum.label();
    }
    trans_time += timer.MicroSeconds();
    Next();
  }
  timer.Stop();
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}
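The key line in load_batch is the offset computation: item i of the batch lives at i * C*H*W in the flat batch buffer, and transformed_data_ is pointed at that slice so the transformer writes in place. A simplified sketch (the plain copy below stands in for DataTransformer's mirror/scale/crop; the function names are illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Item i of an N x C x H x W batch starts at i * C*H*W in the flat buffer;
// this is what batch->data_.offset(item_id) computes above.
int ItemOffset(int channels, int height, int width, int item_id) {
  return item_id * channels * height * width;
}

// Sketch of filling a batch: each "transform" here is a plain copy of the
// datum's pixels (real Caffe applies mirror/scale/crop via DataTransformer).
void FillBatch(std::vector<float>& batch, int channels, int height, int width,
               const std::vector<std::vector<float>>& datums) {
  for (std::size_t i = 0; i < datums.size(); ++i) {
    float* dst =
        batch.data() + ItemOffset(channels, height, width, static_cast<int>(i));
    std::copy(datums[i].begin(), datums[i].end(), dst);
  }
}
```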

3. Reading H5

Reading H5 files works much like reading LMDB. H5 stores data in a key-value-like form, and the data inside has been thoroughly shuffled, which avoids the training failures caused by generating data without shuffling.
As an aside, what goes wrong in image classification when the training images are not shuffled? The predicted class probabilities come out essentially at random, and the network simply fails to learn.
To read H5 files, the layer first reads the list file, determines the output dimensions from the key-value pairs, and sets the layer's output dimensions accordingly.

template <typename Dtype>
void HDF5DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  // Refuse transformation parameters since HDF5 is totally generic.
  CHECK(!this->layer_param_.has_transform_param()) <<
      this->type() << " does not transform data.";
  // Read the source to parse the filenames.
  const string& source = this->layer_param_.hdf5_data_param().source(); // path of the H5 file list
  LOG(INFO) << "Loading list of HDF5 filenames from: " << source;
  hdf_filenames_.clear(); // list of H5 files
  // read every H5 file path in the list (stored as absolute paths) into a vector<string>
  std::ifstream source_file(source.c_str()); 
  if (source_file.is_open()) {
    std::string line;
    while (source_file >> line) {
      hdf_filenames_.push_back(line);
    }
  } else {
    LOG(FATAL) << "Failed to open source file: " << source;
  }
  source_file.close();
  num_files_ = hdf_filenames_.size(); // total number of H5 files
  current_file_ = 0;
  LOG(INFO) << "Number of HDF5 files: " << num_files_;
  CHECK_GE(num_files_, 1) << "Must have at least 1 HDF5 filename listed in "
    << source;

  file_permutation_.clear();
  file_permutation_.resize(num_files_); // ordering of the H5 files for training; shuffled below
  // Default to identity permutation.
  for (int i = 0; i < num_files_; i++) {
    file_permutation_[i] = i;
  }

  // Shuffle if needed.
  if (this->layer_param_.hdf5_data_param().shuffle()) {
    std::random_shuffle(file_permutation_.begin(), file_permutation_.end());
  }

  // Load the first HDF5 file and initialize the line counter.
  LoadHDF5FileData(hdf_filenames_[file_permutation_[current_file_]].c_str());
  current_row_ = 0;

  // Reshape blobs.
  const int batch_size = this->layer_param_.hdf5_data_param().batch_size(); // this layer's batch size
  const int top_size = this->layer_param_.top_size(); // number of this layer's outputs; by default the first is the data, the second the label
  vector<int> top_shape;
  // set the output blobs' dimensions from the loaded data
  for (int i = 0; i < top_size; ++i) {
    top_shape.resize(hdf_blobs_[i]->num_axes());
    top_shape[0] = batch_size;
    for (int j = 1; j < top_shape.size(); ++j) {
      top_shape[j] = hdf_blobs_[i]->shape(j);
    }
    top[i]->Reshape(top_shape);
  }
}
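The file_permutation_ setup above starts from the identity permutation and shuffles it in place. A small sketch of the same idea; note Caffe calls std::random_shuffle, which was removed in C++17, so this version uses std::shuffle with an explicit engine:

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <random>
#include <vector>

// Sketch of the file_permutation_ setup: identity permutation, then shuffle.
std::vector<int> MakeFilePermutation(int num_files, unsigned seed) {
  std::vector<int> perm(num_files);
  std::iota(perm.begin(), perm.end(), 0);  // 0, 1, ..., num_files - 1
  std::mt19937 rng(seed);
  std::shuffle(perm.begin(), perm.end(), rng);
  return perm;
}
```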

Next comes actually loading an H5 file, done in LoadHDF5FileData: each call reads all of that file's images and labels into memory at once, and batches are then taken from memory one at a time.
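That per-file, all-in-memory design means the forward pass only has to slice out the next batch_size rows, wrapping around when the rows are exhausted. A simplified single-file sketch of that indexing (NextBatch is an illustrative helper; the real layer advances to the next file in the permuted list instead of wrapping within one file):

```cpp
#include <cassert>
#include <vector>

// Single-file sketch of HDF5DataLayer's batch indexing: all rows are already
// in memory, and each call copies the next batch_size rows, wrapping to row 0
// once the rows run out.
std::vector<int> NextBatch(const std::vector<int>& rows, int batch_size,
                           int& current_row) {
  std::vector<int> batch(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    batch[i] = rows[current_row];
    current_row = (current_row + 1) % static_cast<int>(rows.size());
  }
  return batch;
}
```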


That covers reading H5 files. Caffe also implements writing H5 files; if interested, see the HDF5OutputLayer class.
