代码笔记:caffe-reid中reid_data_layer源码解析

#include <stdint.h>
#include <cfloat>

#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/reid_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include <boost/thread.hpp>

namespace caffe {

/// Destructor: stops the layer's internal (prefetch) thread before the
/// layer object goes away.
template <typename Dtype>
ReidDataLayer<Dtype>::~ReidDataLayer() { this->StopInternalThread(); }

/// Returns the next raw value drawn from this layer's private RNG.
/// Aborts (CHECK) if prefetch_rng_ has not been created yet; it is seeded in
/// DataLayerSetUp.
template <typename Dtype>
unsigned int ReidDataLayer<Dtype>::RandRng() {
  CHECK(prefetch_rng_);
  caffe::rng_t* const gen =
      static_cast<caffe::rng_t*>(prefetch_rng_->generator());
  const unsigned int value = (*gen)();
  return value;
}
(1) 改写了 ImageDataLayer 中的 DataLayerSetUp 函数
/**
 * Adapted from ImageDataLayer::DataLayerSetUp.
 *
 * Reads the list file `reid_data_param.source`. Each non-blank line is
 * "<image path> <integer label>", split at the LAST space, e.g.
 *   /path/to/bounding_box_train/0002_c1s1_000451_03.jpg 0
 * Then it:
 *   - builds label_set[label] = indices (into lines_) of that label's images;
 *     labels must cover the contiguous range [0, max_label], each with at
 *     least one image;
 *   - seeds prefetch_rng_ and reads pos_fraction / neg_fraction;
 *   - decodes every listed image once into cv_imgs_;
 *   - shapes top[0] to 2 * batch_size images and each prefetch slot's
 *     data_ / datap_ to batch_size images;
 *   - when labels are output, shapes top[1] to 2 * batch_size and each
 *     prefetch slot's label_ / labelp_ / labeldif_ to batch_size.
 *
 * @param bottom unused (data layer).
 * @param top    top[0] receives the image shape; top[1] the label shape when
 *               output_labels_ is set.
 */
template <typename Dtype>
void ReidDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  DLOG(INFO) << "ReidDataLayer : DataLayerSetUp";

  // Main Data Layer Set up
  const int new_height = this->layer_param_.reid_data_param().new_height();
  const int new_width  = this->layer_param_.reid_data_param().new_width();
  const bool is_color  = this->layer_param_.reid_data_param().is_color();

  CHECK((new_height == 0 && new_width == 0) ||
      (new_height > 0 && new_width > 0)) << "Current implementation requires "
      "new_height and new_width to be set at the same time.";

  // Read the file with filenames and labels.
  const string& source = this->layer_param_.reid_data_param().source();
  LOG(INFO) << "Opening file " << source;
  std::ifstream infile(source.c_str());
  // Fail here with a clear message; otherwise an unreadable file only shows
  // up later as a confusing label-range check failure.
  CHECK(infile.is_open()) << "Failed to open file " << source;

  string line;
  int mx_label = -1;
  int mi_label = INT_MAX;
  this->lines_.clear();
  while (std::getline(infile, line)) {
    // Tolerate CRLF files and blank lines instead of turning them into a
    // bogus ("", 0) entry that only fails at image-load time.
    if (!line.empty() && line[line.size() - 1] == '\r') {
      line.erase(line.size() - 1);
    }
    if (line.empty()) {
      continue;
    }
    // The label follows the LAST space, so the path part may contain spaces.
    const size_t pos = line.find_last_of(' ');
    CHECK(pos != string::npos)
        << "Malformed line (expected \"<path> <label>\"): " << line;
    const int label = atoi(line.substr(pos + 1).c_str());
    this->lines_.push_back(std::make_pair(line.substr(0, pos), label));
    mx_label = std::max(mx_label, label);
    mi_label = std::min(mi_label, label);
  }
  infile.close();
  // Must come before the label checks: with no lines, mi_label is still
  // INT_MAX and the CHECK_EQ below would hide the real problem.
  CHECK(!this->lines_.empty()) << "File is empty";

  // Labels must start at 0 ...
  CHECK_EQ(mi_label, 0);
  this->label_set.clear();
  this->label_set.resize(mx_label + 1);
  // label_set[label] collects the indices (into lines_) of that label's images.
  for (size_t index = 0; index < this->lines_.size(); index++) {
    this->label_set[this->lines_[index].second].push_back(index);
  }
  // ... and every label in [0, mx_label] must have at least one image.
  for (size_t index = 0; index < this->label_set.size(); index++) {
    CHECK_GT(this->label_set[index].size(), 0) << "label : " << index << " has no images";
  }

  LOG(INFO) << "A total of " << this->lines_.size() << " images. Label : [" << mi_label << ", " << mx_label << "]";
  LOG(INFO) << "A total of " << this->label_set.size() << " persons";

  // Per-layer RNG, consumed through RandRng().
  const unsigned int prefetch_rng_seed = caffe_rng_rand();
  prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));

  // Initialised to the total number of images.
  this->left_images = this->lines_.size();
  this->pos_fraction = this->layer_param_.reid_data_param().pos_fraction();
  // NOTE(review): neg_fraction is read from pos_fraction(). This looks like a
  // copy-paste slip, but it is kept because it cannot be verified here that
  // ReidDataParameter declares a neg_fraction field. Confirm against
  // caffe.proto before changing.
  this->neg_fraction = this->layer_param_.reid_data_param().pos_fraction();

  // Decode every listed image once; load_batch() reads them from cv_imgs_.
  this->cv_imgs_.clear();
  this->cv_imgs_.reserve(this->lines_.size());
  for (size_t lines_id = 0; lines_id < this->lines_.size(); lines_id++) {
    cv::Mat cv_img = ReadImageToCVMat(this->lines_[lines_id].first,
                                      new_height, new_width, is_color);
    CHECK(cv_img.data) << "Could not load " << this->lines_[lines_id].first;
    this->cv_imgs_.push_back(cv_img);
  }

  // Use the first (already decoded) image to initialise the top blob shape,
  // rather than reading it from disk a second time.
  const cv::Mat& cv_img = this->cv_imgs_[0];

  const int batch_size = this->layer_param_.reid_data_param().batch_size();

  // Use data_transformer to infer the expected blob shape from the image.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
  vector<int> prefetch_top_shape = top_shape;
  this->transformed_data_.Reshape(top_shape);
  // top[0] holds 2 * batch_size images; each prefetch slot holds batch_size
  // images in data_ and another batch_size images in datap_.
  top_shape[0] = batch_size * 2;
  prefetch_top_shape[0] = batch_size;
  top[0]->Reshape(top_shape);
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].data_.Reshape(prefetch_top_shape);
    this->prefetch_[i].datap_.Reshape(prefetch_top_shape);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label
  if (this->output_labels_) {
    vector<int> label_shape(1, batch_size * 2);
    top[1]->Reshape(label_shape);
    vector<int> prefetch_label_shape(1, batch_size);
    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
      // Three label blobs per slot: anchor label, pair label, same-label flag.
      this->prefetch_[i].label_.Reshape(prefetch_label_shape);
      this->prefetch_[i].labelp_.Reshape(prefetch_label_shape);
      this->prefetch_[i].labeldif_.Reshape(prefetch_label_shape);
    }
    LOG(INFO) << "output label size : " << top[1]->shape_string();
  }
}
(2) load_batch 函数
// This function is called on prefetch thread
/**
 * Fills one prefetch slot with batch_size (anchor, pair) image pairs.
 * Called on the prefetch thread.
 *
 * For item i:
 *   - the anchor id comes from batch_ids(), the pair id from
 *     batch_pairs(anchors); both index the pre-decoded cv_imgs_ / lines_;
 *   - the anchor image is transformed into batch->data_[i], the pair image
 *     into batch->datap_[i];
 *   - when labels are output: batch->label_[i] = anchor label,
 *     batch->labelp_[i] = pair label, batch->labeldif_[i] = 1 if the two
 *     labels are equal, else 0.
 *
 * @param batch prefetch slot to fill. Its data blobs are reshaped here; its
 *              label blobs are expected to have been shaped by
 *              DataLayerSetUp, which only does so when output_labels_ is set.
 */
template<typename Dtype>
void ReidDataLayer<Dtype>::load_batch(ReidBatch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());
  const int batch_size = this->layer_param_.reid_data_param().batch_size();
  // Anchor image ids for this batch (selection policy lives in batch_ids()).
  const vector<size_t> batches = this->batch_ids();
  // One partner image id per anchor (selection policy lives in batch_pairs()).
  const vector<size_t> batches_pair = this->batch_pairs(batches);

  CHECK_EQ(batches.size(), static_cast<size_t>(batch_size));
  CHECK_EQ(batches_pair.size(), static_cast<size_t>(batch_size));
  // Reshape according to the first image of each batch
  // on single input batches allows for inputs of varying dimension.
  const cv::Mat& first_img = this->cv_imgs_[batches[0]];
  CHECK(first_img.data) << "Could not load " << this->lines_[batches[0]].first;
  // Use data_transformer to infer the expected blob shape from a cv_img.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(first_img);
  this->transformed_data_.Reshape(top_shape);
  // Reshape batch according to the batch_size.
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape);
  batch->datap_.Reshape(top_shape);

  Dtype* prefetch_data = batch->data_.mutable_cpu_data();
  Dtype* prefetch_datap = batch->datap_.mutable_cpu_data();
  // DataLayerSetUp only shapes the label blobs when output_labels_ is set, so
  // they must not be touched otherwise.
  Dtype* prefetch_label = NULL;
  Dtype* prefetch_labelp = NULL;
  Dtype* prefetch_labeldif = NULL;
  if (this->output_labels_) {
    prefetch_label = batch->label_.mutable_cpu_data();
    prefetch_labelp = batch->labelp_.mutable_cpu_data();
    prefetch_labeldif = batch->labeldif_.mutable_cpu_data();
  }

  for (int item_id = 0; item_id < batch_size; ++item_id) {
    // Fetch the two cached images for this item.
    timer.Start();
    const size_t true_idx = batches[item_id];
    const size_t pair_idx = batches_pair[item_id];
    const cv::Mat& cv_img_true = this->cv_imgs_[true_idx];
    const cv::Mat& cv_img_pair = this->cv_imgs_[pair_idx];
    CHECK(cv_img_true.data) << "Could not load " << this->lines_[true_idx].first;
    CHECK(cv_img_pair.data) << "Could not load " << this->lines_[pair_idx].first;
    read_time += timer.MicroSeconds();
    timer.Start();

    // Apply transformations (mirror, crop...) to the anchor image.
    const int t_offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(prefetch_data + t_offset);
    this->data_transformer_->Transform(cv_img_true, &(this->transformed_data_));

    // Pair Data
    const int p_offset = batch->datap_.offset(item_id);
    this->transformed_data_.set_cpu_data(prefetch_datap + p_offset);
    this->data_transformer_->Transform(cv_img_pair, &(this->transformed_data_));
    trans_time += timer.MicroSeconds();

    const int true_label = this->lines_[true_idx].second;
    const int pair_label = this->lines_[pair_idx].second;
    CHECK_GE(true_label, 0);
    CHECK_GE(pair_label, 0);
    CHECK_LT(static_cast<size_t>(true_label), this->label_set.size());
    CHECK_LT(static_cast<size_t>(pair_label), this->label_set.size());

    if (this->output_labels_) {
      prefetch_label[item_id] = true_label;
      prefetch_labelp[item_id] = pair_label;
      // 1 when anchor and pair share the same label, 0 otherwise.
      prefetch_labeldif[item_id] =
          (true_label == pair_label) ? Dtype(1) : Dtype(0);

      DLOG(INFO) << "Idx : " << item_id << " : " << true_label << " vs " << pair_label << " ..=.. " << prefetch_labeldif[item_id];
    }
  }
  batch_timer.Stop();
  DLOG(INFO) << "Pair Idx : (" << batches[0] << "," << batches_pair[0] << ")";
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

// Caffe boilerplate: explicitly instantiate the ReidDataLayer templates, then
// register a creator for the layer with Caffe's layer factory. The registered
// type name is derived from the `ReidData` argument.
INSTANTIATE_CLASS(ReidDataLayer);
REGISTER_LAYER_CLASS(ReidData);

}  // namespace caffe
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值