#include <stdint.h>

#include <algorithm>
#include <cfloat>
#include <climits>
#include <cstdlib>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/reid_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include <boost/thread.hpp>
namespace caffe {
/// Destructor: stops the internal thread before the object is torn down.
template <typename Dtype>
ReidDataLayer<Dtype>::~ReidDataLayer() {
  this->StopInternalThread();
}
/// Draws the next value from the layer's prefetch random number generator.
///
/// Fails a CHECK if prefetch_rng_ has not been created yet.
template <typename Dtype>
unsigned int ReidDataLayer<Dtype>::RandRng() {
  CHECK(prefetch_rng_);
  // generator() hands back an untyped pointer; view it as the rng engine.
  caffe::rng_t& engine =
      *static_cast<caffe::rng_t*>(prefetch_rng_->generator());
  return engine();
}
// (1) 改写了 ImageDataLayer 中的 DataLayerSetUp 函数
/**
 * @brief One-time setup: parses the image list, caches every image in memory
 *        and shapes the top blobs and prefetch buffers.
 *
 * The list file named by reid_data_param().source() is read line by line.
 * Each non-empty line must have the form "<image path> <integer label>",
 * split at the LAST space, e.g.
 *   /data/Market-1501/bounding_box_train/0002_c1s1_000451_03.jpg 0
 * Labels must be non-negative, start at 0, and every label in [0, max_label]
 * must own at least one image.
 *
 * Side effects visible in this function: fills lines_, label_set and cv_imgs_,
 * seeds prefetch_rng_, sets left_images / pos_fraction / neg_fraction,
 * reshapes transformed_data_, top[0] (and top[1] when output_labels_ is set)
 * and every prefetch buffer.
 *
 * @param bottom unused by this function.
 * @param top    top[0] is shaped to 2 * batch_size items; top[1] (labels) is
 *               shaped to 2 * batch_size entries when output_labels_ is set.
 */
template <typename Dtype>
void ReidDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  DLOG(INFO) << "ReidDataLayer : DataLayerSetUp";
  // Main Data Layer Set up
  const int new_height = this->layer_param_.reid_data_param().new_height();
  const int new_width = this->layer_param_.reid_data_param().new_width();
  const bool is_color = this->layer_param_.reid_data_param().is_color();
  CHECK((new_height == 0 && new_width == 0) ||
      (new_height > 0 && new_width > 0)) << "Current implementation requires "
      "new_height and new_width to be set at the same time.";
  const int batch_size = this->layer_param_.reid_data_param().batch_size();
  CHECK_GT(batch_size, 0) << "batch_size must be positive";

  // Read the file with filenames and labels.
  const string& source = this->layer_param_.reid_data_param().source();
  LOG(INFO) << "Opening file " << source;
  std::ifstream infile(source.c_str());
  CHECK(infile.good()) << "Could not open image list file " << source;

  this->lines_.clear();
  string line;
  int mx_label = -1;
  int mi_label = INT_MAX;
  while (std::getline(infile, line)) {
    if (line.empty()) {
      continue;  // tolerate blank lines instead of recording an empty path
    }
    // Split at the last space: path before it, label after it.
    const size_t pos = line.find_last_of(' ');
    CHECK(pos != string::npos && pos > 0)
        << "Malformed line (expected \"<path> <label>\"): " << line;
    const string label_text = line.substr(pos + 1);
    char* parse_end = NULL;
    const long parsed = strtol(label_text.c_str(), &parse_end, 10);
    // strtol leaves parse_end at the start when no digits were consumed;
    // atoi would silently have turned that into label 0.
    CHECK(parse_end != label_text.c_str())
        << "Missing or non-numeric label in line: " << line;
    CHECK_GE(parsed, 0L) << "Negative label in line: " << line;
    // Strictly below INT_MAX so that mx_label + 1 cannot overflow.
    CHECK_LT(parsed, static_cast<long>(INT_MAX))
        << "Label out of range in line: " << line;
    const int label = static_cast<int>(parsed);
    this->lines_.push_back(std::make_pair(line.substr(0, pos), label));
    mx_label = std::max(mx_label, label);
    mi_label = std::min(mi_label, label);
  }
  infile.close();

  // Check emptiness first so an empty list reports "File is empty" rather
  // than a confusing label mismatch against INT_MAX.
  CHECK(!this->lines_.empty()) << "File is empty";
  CHECK_EQ(mi_label, 0);

  // label_set[l] collects the indices (into lines_) of all images labelled l.
  this->label_set.clear();
  this->label_set.resize(mx_label + 1);
  for (size_t index = 0; index < this->lines_.size(); index++) {
    const int label = this->lines_[index].second;
    this->label_set[label].push_back(index);
  }
  for (size_t index = 0; index < this->label_set.size(); index++) {
    CHECK(!this->label_set[index].empty())
        << "label : " << index << " has no images";
  }
  LOG(INFO) << "A total of " << this->lines_.size() << " images. Label : ["
      << mi_label << ", " << mx_label << "]";
  LOG(INFO) << "A total of " << this->label_set.size() << " persons";

  // Random generator used by RandRng().
  const unsigned int prefetch_rng_seed = caffe_rng_rand();
  this->prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));

  // Starts out as the total number of images.
  this->left_images = this->lines_.size();
  this->pos_fraction = this->layer_param_.reid_data_param().pos_fraction();
  // Previously initialised from pos_fraction(), which looks like a copy-paste
  // slip. NOTE(review): assumes the proto declares a neg_fraction field.
  this->neg_fraction = this->layer_param_.reid_data_param().neg_fraction();

  // Decode every listed image once and keep it in memory.
  this->cv_imgs_.clear();
  this->cv_imgs_.reserve(this->lines_.size());
  for (size_t lines_id = 0; lines_id < this->lines_.size(); lines_id++) {
    const string& path = this->lines_[lines_id].first;
    cv::Mat cv_img = ReadImageToCVMat(path, new_height, new_width, is_color);
    CHECK(cv_img.data) << "Could not load " << path;
    this->cv_imgs_.push_back(cv_img);
  }

  // Use the first (already decoded) image and data_transformer to infer the
  // expected blob shape; no need to read the file from disk a second time.
  vector<int> top_shape =
      this->data_transformer_->InferBlobShape(this->cv_imgs_[0]);
  vector<int> prefetch_top_shape = top_shape;
  this->transformed_data_.Reshape(top_shape);

  // top[0] holds 2 * batch_size items, while each prefetch buffer keeps two
  // blobs (data_ and datap_) of batch_size items each.
  top_shape[0] = batch_size * 2;
  prefetch_top_shape[0] = batch_size;
  top[0]->Reshape(top_shape);
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].data_.Reshape(prefetch_top_shape);
    this->prefetch_[i].datap_.Reshape(prefetch_top_shape);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();

  // label
  if (this->output_labels_) {
    vector<int> label_shape(1, batch_size * 2);
    top[1]->Reshape(label_shape);
    vector<int> prefetch_label_shape(1, batch_size);
    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
      // Three label blobs per prefetch buffer, batch_size entries each.
      this->prefetch_[i].label_.Reshape(prefetch_label_shape);
      this->prefetch_[i].labelp_.Reshape(prefetch_label_shape);
      this->prefetch_[i].labeldif_.Reshape(prefetch_label_shape);
    }
    LOG(INFO) << "output label size : " << top[1]->shape_string();
  }
}
// (2) load_batch 函数:在预取线程中加载一个 batch 的图像对
/**
 * @brief Fills one prefetch batch with batch_size image pairs.
 *
 * This function is called on prefetch thread.
 *
 * batch_ids() supplies batch_size image indices and batch_pairs() supplies a
 * partner index for each of them. For every position the "true" image is
 * transformed into batch->data_ and its partner into batch->datap_. When
 * output_labels_ is set, label_ and labelp_ receive the two labels and
 * labeldif_ receives 1 if the labels are equal, otherwise 0.
 *
 * @param batch prefetch buffer to fill; data_ and datap_ are reshaped here
 *              from the first sampled image.
 */
template<typename Dtype>
void ReidDataLayer<Dtype>::load_batch(ReidBatch <Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());
  const int batch_size = this->layer_param_.reid_data_param().batch_size();

  // batch_size image indices, plus one partner index for each of them.
  const vector<size_t> batches = this->batch_ids();
  const vector<size_t> batches_pair = this->batch_pairs(batches);
  CHECK_EQ(batches.size(), static_cast<size_t>(batch_size));
  CHECK_EQ(batches_pair.size(), static_cast<size_t>(batch_size));

  // Reshape according to the first image of each batch
  // on single input batches allows for inputs of varying dimension.
  CHECK_LT(batches[0], this->cv_imgs_.size());
  const cv::Mat& first_img = this->cv_imgs_[batches[0]];
  CHECK(first_img.data) << "Could not load "
      << this->lines_[batches[0]].first;
  // Use data_transformer to infer the expected blob shape from a cv_img.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(first_img);
  this->transformed_data_.Reshape(top_shape);
  // Reshape batch according to the batch_size.
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape);
  batch->datap_.Reshape(top_shape);

  Dtype* prefetch_data = batch->data_.mutable_cpu_data();
  Dtype* prefetch_datap = batch->datap_.mutable_cpu_data();
  // The label blobs are only shaped in DataLayerSetUp when output_labels_ is
  // set, so they are only touched here under the same condition.
  Dtype* prefetch_label = NULL;
  Dtype* prefetch_labelp = NULL;
  Dtype* prefetch_labeldif = NULL;
  if (this->output_labels_) {
    prefetch_label = batch->label_.mutable_cpu_data();
    prefetch_labelp = batch->labelp_.mutable_cpu_data();
    prefetch_labeldif = batch->labeldif_.mutable_cpu_data();
  }

  for (int item_id = 0; item_id < batch_size; ++item_id) {
    // get a blob
    timer.Start();
    // Indices of the two images forming this pair.
    const size_t true_idx = batches[item_id];
    const size_t pair_idx = batches_pair[item_id];
    CHECK_LT(true_idx, this->cv_imgs_.size());
    CHECK_LT(pair_idx, this->cv_imgs_.size());
    // cv::Mat copies share the cached pixel buffer; nothing is re-read here.
    cv::Mat cv_img_true = this->cv_imgs_[true_idx];
    cv::Mat cv_img_pair = this->cv_imgs_[pair_idx];
    CHECK(cv_img_true.data) << "Could not load "
        << this->lines_[true_idx].first;
    CHECK(cv_img_pair.data) << "Could not load "
        << this->lines_[pair_idx].first;
    read_time += timer.MicroSeconds();

    timer.Start();
    // Apply transformations (mirror, crop...) to the image
    const int t_offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(prefetch_data + t_offset);
    this->data_transformer_->Transform(cv_img_true,
        &(this->transformed_data_));
    // Pair Data
    const int p_offset = batch->datap_.offset(item_id);
    this->transformed_data_.set_cpu_data(prefetch_datap + p_offset);
    this->data_transformer_->Transform(cv_img_pair,
        &(this->transformed_data_));
    trans_time += timer.MicroSeconds();

    const int true_label = this->lines_[true_idx].second;
    const int pair_label = this->lines_[pair_idx].second;
    CHECK_GE(true_label, 0);
    CHECK_GE(pair_label, 0);
    CHECK_LT(static_cast<size_t>(true_label), this->label_set.size());
    CHECK_LT(static_cast<size_t>(pair_label), this->label_set.size());

    // 1 when both images carry the same label, 0 otherwise.
    const Dtype same_label = static_cast<Dtype>(true_label == pair_label);
    if (this->output_labels_) {
      prefetch_label[item_id] = true_label;
      prefetch_labelp[item_id] = pair_label;
      prefetch_labeldif[item_id] = same_label;
    }
    DLOG(INFO) << "Idx : " << item_id << " : " << true_label << " vs "
        << pair_label << " ..=.. " << same_label;
  }
  batch_timer.Stop();
  DLOG(INFO) << "Pair Idx : (" << batches[0] << "," << batches_pair[0] << ")";
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << " Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}
// Caffe helper macros defined outside this file.
// NOTE(review): presumably INSTANTIATE_CLASS emits the explicit template
// instantiations of ReidDataLayer and REGISTER_LAYER_CLASS registers the layer
// under the type name "ReidData" - confirm against the Caffe headers.
INSTANTIATE_CLASS(ReidDataLayer);
REGISTER_LAYER_CLASS(ReidData);
} // namespace caffe
代码笔记:caffe-reid中reid_data_layer源码解析
最新推荐文章于 2024-04-17 14:34:51 发布