caffe代码阅读2:DataTransformer以及io的实现细节

一、DataTransformer的作用简介

该类主要负责对数据进行预处理,将Datum、const vector<Datum>、cv::Mat&、vector<cv::Mat> 、Blob<Dtype>*类型的数据变换到目标大小的blob。
此外还负责根据参数中指定的预处理参数推断出处理后的数据的shape。

在正式介绍之前,先给个例子:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. layer {  
  2.   name: "jointimagedata"  
  3.   type: "JointImage"  
  4.   top: "jointimagedata"  
  5.   top: "label"  
  6.   include {  
  7.     phase: TEST  
  8.   }  
  9.   transform_param {  
  10.     mirror: true  
  11.     crop_size: 227  
  12.     mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"  
  13.   }  
  14.   slidewindow_param {  
  15.     root_folder: "D:/数据集/FLIC/FLIC-full"  
  16.     filelistpath: "/imglist.txt"  
  17.     batch_size: 300  
  18.   }  
  19. }  

上述配置文件中就包含了transform_param这个参数,利用该参数可以实现crop,mirror,减去均值等功能。
该类用到了TransformationParameter。其在caffe.proto的定义为
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // Message that stores parameters used to apply transformation  
  2. // to the data layer's data  
  3. message TransformationParameter {  
  4.   // For data pre-processing, we can do simple scaling and subtracting the  
  5.   // data mean, if provided. Note that the mean subtraction is always carried  
  6.   // out before scaling.  
  7.   optional float scale = 1 [default = 1];  
  8.   // Specify if we want to randomly mirror data.  
  9.   optional bool mirror = 2 [default = false];  
  10.   // Specify if we would like to randomly crop an image.  
  11.   optional uint32 crop_size = 3 [default = 0];  
  12.   // mean_file and mean_value cannot be specified at the same time  
  13.   optional string mean_file = 4;  
  14.   // if specified can be repeated once (would substract it from all the channels)  
  15.   // or can be repeated the same number of times as channels  
  16.   // (would subtract them from the corresponding channel)  
  17.   repeated float mean_value = 5;  
  18.   // Force the decoded image to have 3 color channels.  
  19.   optional bool force_color = 6 [default = false];  
  20.   // Force the decoded image to have 1 color channels.  
  21.   optional bool force_gray = 7 [default = false];  
  22. }  

二、DataTransformer类的详细介绍

1)构造函数

  // 构造函数
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. explicit DataTransformer(const TransformationParameter& param, Phase phase);  
  2. virtual ~DataTransformer() {}  

2)成员变量

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // 变换所使用的参数  
  2. TransformationParameter param_;  
  3. // 随机数生成器的种子  
  4. shared_ptr<Caffe::RNG> rng_;  
  5. // 是训练还是测试?  
  6. Phase phase_;  
  7. // 数据均值 blob  
  8. Blob<Dtype> data_mean_;  
  9. // 数据均值blob的容器  
  10. vector<Dtype> mean_values_;  

3)成员函数

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1.  // 初始化随机数生成器,因为在对数据进行变换的时候有可能用到,比如说打乱数据的输入顺序  
  2.   void InitRand();  
  3.    // 对Datum的数据进行变换,放入到transformed_blob中  
  4.   void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);  
  5.    // 对Datum容器的数据进行变换翻入到transformed_blob  
  6.   void Transform(const vector<Datum> & datum_vector,  
  7.                 Blob<Dtype>* transformed_blob);  
  8.    // 如果定义OpenCV还可能对mat容器数据类型的数据进行变换  
  9.   void Transform(const vector<cv::Mat> & mat_vector,  
  10.                 Blob<Dtype>* transformed_blob);  
  11.    // 将opencv读取的单个图像转换到blob中去  
  12.   void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);  
  13.    // 将输入的blob进行变换,可能是取出blob的中的一部分数据到新的blob  
  14.   void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);  
  15.    // 根据Datum获取blob的形状  
  16.   vector<int> InferBlobShape(const Datum& datum);  
  17.    // 根据Datum容器获取blob的形状  
  18.   vector<int> InferBlobShape(const vector<Datum> & datum_vector);  
  19.   // 根据Mat容器获取blob的形状  
  20.   vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);  
  21.   // 根据Mat获取blob的形状  
  22.   vector<int> InferBlobShape(const cv::Mat& cv_img);  
  23. // 生成从0到n-1的服从均匀分布的随机数,要求继承他的都必须实现如何生成随机数  
  24.   virtual int Rand(int n);  
  25.   // 将给定的Datum进行转换  
  26.   void Transform(const Datum& datum, Dtype* transformed_data);  

4)具体函数的实现:

首先是构造函数
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. 在介绍构造函数之前不得不先贴出BlobShape和、BlobProto这两个结构体的在caffe.proto中的定义。  
  2. message BlobShape {  
  3.   repeated int64 dim = 1 [packed = true]; //blob的形状  
  4. }  
  5.   
  6. message BlobProto {  
  7.   optional BlobShape shape = 7;  
  8.   repeated float data = 5 [packed = true]; // 前向传播的数据  
  9.   repeated float diff = 6 [packed = true]; // 反向传播的数据  
  10.   repeated double double_data = 8 [packed = true]; // double类型的前向传播的数据  
  11.   repeated double double_diff = 9 [packed = true]; // 依次类推  
  12.   
  13.   // 4D dimensions -- deprecated.  Use "shape" instead.  
  14.   // 下面是为了兼容  
  15.   optional int32 num = 1 [default = 0];  
  16.   optional int32 channels = 2 [default = 0];  
  17.   optional int32 height = 3 [default = 0];  
  18.   optional int32 width = 4 [default = 0];  
  19. }  
  20.   
  21.   
  22. template<typename Dtype>  
  23. DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,  
  24.     Phase phase)  
  25.     : param_(param), phase_(phase) {  
  26.   // check if we want to use mean_file  
  27.   if (param_.has_mean_file()) {  
  28.     CHECK_EQ(param_.mean_value_size(), 0) <<  
  29.       "Cannot specify mean_file and mean_value at the same time";  
  30.     const string& mean_file = param.mean_file();  
  31.     if (Caffe::root_solver()) {  
  32.       LOG(INFO) << "Loading mean file from: " << mean_file;  
  33.     }  
  34.     BlobProto blob_proto;  
  35.     ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);  
  36.     data_mean_.FromProto(blob_proto);  
  37.   }  
  38.   // check if we want to use mean_value  
  39.   if (param_.mean_value_size() > 0) {  
  40.     CHECK(param_.has_mean_file() == false) <<  
  41.       "Cannot specify mean_file and mean_value at the same time";  
  42.     for (int c = 0; c < param_.mean_value_size(); ++c) {  
  43.       mean_values_.push_back(param_.mean_value(c));  
  44.     }  
  45.   }  
  46. }  
具体的实现如下:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3. #endif  // USE_OPENCV  
  4.   
  5. #include <string>  
  6. #include <vector>  
  7.   
  8. #include "caffe/data_transformer.hpp"  
  9. #include "caffe/util/io.hpp"  
  10. #include "caffe/util/math_functions.hpp"  
  11. #include "caffe/util/rng.hpp"  
  12.   
  13. namespace caffe {  
  14. // 构造函数  
  15. template<typename Dtype>  
  16. DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,  
  17.     Phase phase)  
  18.     : param_(param), phase_(phase) {  
  19.   // check if we want to use mean_file  
  20.   // 判断是否有平均值文件  
  21.   if (param_.has_mean_file()) {  
  22.     CHECK_EQ(param_.mean_value_size(), 0) <<  
  23.       "Cannot specify mean_file and mean_value at the same time";  
  24.     // 平均值文件的路径  
  25.     const string& mean_file = param.mean_file();  
  26.     if (Caffe::root_solver()) {  
  27.       LOG(INFO) << "Loading mean file from: " << mean_file;  
  28.     }  
  29.     BlobProto blob_proto;  
  30.     ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);  
  31.     data_mean_.FromProto(blob_proto);  
  32.   }  
  33.   // check if we want to use mean_value  
  34.   if (param_.mean_value_size() > 0) {  
  35.     CHECK(param_.has_mean_file() == false) <<  
  36.       "Cannot specify mean_file and mean_value at the same time";  
  37.     for (int c = 0; c < param_.mean_value_size(); ++c) {  
  38.       mean_values_.push_back(param_.mean_value(c));  
  39.     }  
  40.   }  
  41. }  
  42.   
  43. template<typename Dtype>  
  44. void DataTransformer<Dtype>::Transform(const Datum& datum,  
  45.                                        Dtype* transformed_data) {  
  46.   // 参考TransformationParameter的定义  
  47.   const string& data = datum.data();  
  48.   const int datum_channels = datum.channels();//数据的channel  
  49.   const int datum_height = datum.height();//数据的行数  
  50.   const int datum_width = datum.width();// 数据的列数  
  51.   
  52.   const int crop_size = param_.crop_size();// crop大小  
  53.   const Dtype scale = param_.scale();// 缩放比例  
  54.   const bool do_mirror = param_.mirror() && Rand(2);// 该参数用于在镜像位置对数据处理  
  55.   const bool has_mean_file = param_.has_mean_file();// 是否有均值文件  
  56.   const bool has_uint8 = data.size() > 0;// 数据是否为uint8还是float类型的  
  57.   const bool has_mean_values = mean_values_.size() > 0;// 是否有每个channel的均值  
  58.   
  59.   // 检查合法性  
  60.   CHECK_GT(datum_channels, 0);  
  61.   CHECK_GE(datum_height, crop_size);  
  62.   CHECK_GE(datum_width, crop_size);  
  63.   
  64.   Dtype* mean = NULL;  
  65.   if (has_mean_file) {// 检查mean_file是否与数据的参数一致  
  66.     CHECK_EQ(datum_channels, data_mean_.channels());  
  67.     CHECK_EQ(datum_height, data_mean_.height());  
  68.     CHECK_EQ(datum_width, data_mean_.width());  
  69.     mean = data_mean_.mutable_cpu_data();  
  70.   }  
  71.   if (has_mean_values) {  
  72.     CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<  
  73.      "Specify either 1 mean_value or as many as channels: " << datum_channels;  
  74.     if (datum_channels > 1 && mean_values_.size() == 1) {  
  75.       // Replicate the mean_value for simplicity  
  76.       for (int c = 1; c < datum_channels; ++c) {  
  77.         mean_values_.push_back(mean_values_[0]);  
  78.       }  
  79.     }  
  80.   }  
  81.   
  82.   int height = datum_height;  
  83.   int width = datum_width;  
  84.   
  85.   // 根据是否需要crop来生成h_off和w_off  
  86.   int h_off = 0;  
  87.   int w_off = 0;  
  88.   if (crop_size) {// 如果crop_size不为0  
  89.     height = crop_size;  
  90.     width = crop_size;  
  91.     // We only do random crop when we do training.  
  92.     // 在训练的时候随机crop图像块,这里需要自己实现Rand这个函数来确定是如何随机的  
  93.     if (phase_ == TRAIN) {  
  94.       h_off = Rand(datum_height - crop_size + 1);// 产生从0到datum_height - crop_size的随机数  
  95.       w_off = Rand(datum_width - crop_size + 1);  
  96.     } else {// 测试的时候不用随机,取图像的中心  
  97.       h_off = (datum_height - crop_size) / 2;  
  98.       w_off = (datum_width - crop_size) / 2;  
  99.     }  
  100.   }  
  101.   
  102.   // 对数据进行变换,主要是将原来的像素值减去均值,然后乘以scale这么一个操作  
  103.   // 如果需要crop则最终转换的Blob的大小即为crop*crop  
  104.   // 如果不是,则最终的Blob大小即为datum_height*datum_width  
  105.   Dtype datum_element;  
  106.   int top_index, data_index;  
  107.   for (int c = 0; c < datum_channels; ++c) {  
  108.     for (int h = 0; h < height; ++h) {  
  109.       for (int w = 0; w < width; ++w) {  
  110.         data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;// 获取数据的索引  
  111.         if (do_mirror) {// 是否需要在镜像位置转换  
  112.           top_index = (c * height + h) * width + (width - 1 - w);//在宽这个坐标上做文章,来实现镜像  
  113.         } else {//  
  114.           top_index = (c * height + h) * width + w;  
  115.         }  
  116.         if (has_uint8) {// 数据如果是uint8则进行转换  
  117.           datum_element =  
  118.             static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));  
  119.         } else {// 否则就是float  
  120.           datum_element = datum.float_data(data_index);  
  121.         }  
  122.         if (has_mean_file) {// 如果有mean_file,则原来的像素值减去均值,然后乘以scale  
  123.           transformed_data[top_index] =  
  124.             (datum_element - mean[data_index]) * scale;  
  125.         } else {  
  126.           if (has_mean_values) {// 否则减去该channel的均值(每个channel有其一个均值),然后乘以scale  
  127.             transformed_data[top_index] =  
  128.               (datum_element - mean_values_[c]) * scale;  
  129.           } else {// 否则如果没有均值那么就直接乘以scale即可  
  130.             transformed_data[top_index] = datum_element * scale;  
  131.           }  
  132.         }  
  133.       }  
  134.     }  
  135.   }  
  136. }  
  137.   
  138.   
  139. template<typename Dtype>  
  140. void DataTransformer<Dtype>::Transform(const Datum& datum,  
  141.                                        Blob<Dtype>* transformed_blob) {  
  142.   // If datum is encoded, decoded and transform the cv::image.  
  143.   if (datum.encoded()) {//  检查是否编码了,如果是则解码  
  144. #ifdef USE_OPENCV  
  145.     // 先检查是不是两个属性都设置, 如果是则说明参数设置有误  
  146.     CHECK(!(param_.force_color() && param_.force_gray()))  
  147.         << "cannot set both force_color and force_gray";  
  148.     cv::Mat cv_img;  
  149.     if (param_.force_color() || param_.force_gray()) {  
  150.         // 如果强制彩色或者强制灰度图像一个成立则使用DecodeDatumToCVMat解码  
  151.     // If force_color then decode in color otherwise decode in gray.  
  152.       cv_img = DecodeDatumToCVMat(datum, param_.force_color());  
  153.     } else {// 否则使用DecodeDatumToCVMatNative解码  
  154.       cv_img = DecodeDatumToCVMatNative(datum);  
  155.     }  
  156.     // Transform the cv::image into blob.  
  157.     // 变换  
  158.     return Transform(cv_img, transformed_blob);  
  159. #else  
  160.     LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";  
  161. #endif  // USE_OPENCV  
  162.   } else {// 如果没有编码则,检查force_color和force_gray是否设置,如果设置则不合法,因为该选项只适合于编码后的数据  
  163.     if (param_.force_color() || param_.force_gray()) {  
  164.       LOG(ERROR) << "force_color and force_gray only for encoded datum";  
  165.     }  
  166.   }  
  167.   
  168.   const int crop_size = param_.crop_size();  
  169.   const int datum_channels = datum.channels();  
  170.   const int datum_height = datum.height();  
  171.   const int datum_width = datum.width();  
  172.   
  173.   // Check dimensions.  
  174.   const int channels = transformed_blob->channels();  
  175.   const int height = transformed_blob->height();  
  176.   const int width = transformed_blob->width();  
  177.   const int num = transformed_blob->num();  
  178.   
  179.   CHECK_EQ(channels, datum_channels);  
  180.   CHECK_LE(height, datum_height);  
  181.   CHECK_LE(width, datum_width);  
  182.   CHECK_GE(num, 1);  
  183.   
  184.   if (crop_size) {  
  185.     CHECK_EQ(crop_size, height);  
  186.     CHECK_EQ(crop_size, width);  
  187.   } else {  
  188.     CHECK_EQ(datum_height, height);  
  189.     CHECK_EQ(datum_width, width);  
  190.   }  
  191.   // 继续变换数据  
  192.   Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  193.   Transform(datum, transformed_data);  
  194. }  
  195.   
  196. template<typename Dtype>  
  197. void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,  
  198.                                        Blob<Dtype>* transformed_blob) {  
  199.   const int datum_num = datum_vector.size();  
  200.   // 变换到的目标blob的形状  
  201.   const int num = transformed_blob->num();  
  202.   const int channels = transformed_blob->channels();  
  203.   const int height = transformed_blob->height();  
  204.   const int width = transformed_blob->width();  
  205.   
  206.   CHECK_GT(datum_num, 0) << "There is no datum to add";  
  207.   CHECK_LE(datum_num, num) <<  
  208.     "The size of datum_vector must be no greater than transformed_blob->num()";  
  209.   // 新建一个uni_blob,里面只有一个batch  
  210.   Blob<Dtype> uni_blob(1, channels, height, width);  
  211.   for (int item_id = 0; item_id < datum_num; ++item_id) {  
  212.     int offset = transformed_blob->offset(item_id);  
  213.     uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);  
  214.     Transform(datum_vector[item_id], &uni_blob);  
  215.   }  
  216. }  
  217.   
  218. #ifdef USE_OPENCV  
  219. template<typename Dtype>  
  220. void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,  
  221.                                        Blob<Dtype>* transformed_blob) {  
  222.   // 获取mat的参数  
  223.   const int mat_num = mat_vector.size();  
  224.   const int num = transformed_blob->num();  
  225.   const int channels = transformed_blob->channels();  
  226.   const int height = transformed_blob->height();  
  227.   const int width = transformed_blob->width();  
  228.   
  229.   CHECK_GT(mat_num, 0) << "There is no MAT to add";  
  230.   CHECK_EQ(mat_num, num) <<  
  231.     "The size of mat_vector must be equals to transformed_blob->num()";  
  232.   //  同上  
  233.   Blob<Dtype> uni_blob(1, channels, height, width);  
  234.   for (int item_id = 0; item_id < mat_num; ++item_id) {  
  235.     int offset = transformed_blob->offset(item_id);  
  236.     uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);  
  237.     Transform(mat_vector[item_id], &uni_blob);  
  238.   }  
  239. }  
  240.   
  241. // 如果是图像的话,需要减去均值乘以scale,判断是不是需要做镜像处理  
  242. // 逻辑与前面类似  
  243. template<typename Dtype>  
  244. void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,  
  245.                                        Blob<Dtype>* transformed_blob) {  
  246.   const int crop_size = param_.crop_size();  
  247.   const int img_channels = cv_img.channels();  
  248.   const int img_height = cv_img.rows;  
  249.   const int img_width = cv_img.cols;  
  250.   
  251.   // Check dimensions.  
  252.   const int channels = transformed_blob->channels();  
  253.   const int height = transformed_blob->height();  
  254.   const int width = transformed_blob->width();  
  255.   const int num = transformed_blob->num();  
  256.   
  257.   CHECK_EQ(channels, img_channels);  
  258.   CHECK_LE(height, img_height);  
  259.   CHECK_LE(width, img_width);  
  260.   CHECK_GE(num, 1);  
  261.   
  262.   CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";  
  263.   
  264.   const Dtype scale = param_.scale();  
  265.   const bool do_mirror = param_.mirror() && Rand(2);  
  266.   const bool has_mean_file = param_.has_mean_file();  
  267.   const bool has_mean_values = mean_values_.size() > 0;  
  268.   
  269.   CHECK_GT(img_channels, 0);  
  270.   CHECK_GE(img_height, crop_size);  
  271.   CHECK_GE(img_width, crop_size);  
  272.   
  273.   Dtype* mean = NULL;  
  274.   if (has_mean_file) {  
  275.     CHECK_EQ(img_channels, data_mean_.channels());  
  276.     CHECK_EQ(img_height, data_mean_.height());  
  277.     CHECK_EQ(img_width, data_mean_.width());  
  278.     mean = data_mean_.mutable_cpu_data();  
  279.   }  
  280.   if (has_mean_values) {  
  281.     CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<  
  282.      "Specify either 1 mean_value or as many as channels: " << img_channels;  
  283.     if (img_channels > 1 && mean_values_.size() == 1) {  
  284.       // Replicate the mean_value for simplicity  
  285.       for (int c = 1; c < img_channels; ++c) {  
  286.         mean_values_.push_back(mean_values_[0]);  
  287.       }  
  288.     }  
  289.   }  
  290.   
  291.   int h_off = 0;  
  292.   int w_off = 0;  
  293.   cv::Mat cv_cropped_img = cv_img;  
  294.   if (crop_size) {  
  295.     CHECK_EQ(crop_size, height);  
  296.     CHECK_EQ(crop_size, width);  
  297.     // We only do random crop when we do training.  
  298.     if (phase_ == TRAIN) {  
  299.       h_off = Rand(img_height - crop_size + 1);  
  300.       w_off = Rand(img_width - crop_size + 1);  
  301.     } else {  
  302.       h_off = (img_height - crop_size) / 2;  
  303.       w_off = (img_width - crop_size) / 2;  
  304.     }  
  305.     cv::Rect roi(w_off, h_off, crop_size, crop_size);  
  306.     cv_cropped_img = cv_img(roi);  
  307.   } else {  
  308.     CHECK_EQ(img_height, height);  
  309.     CHECK_EQ(img_width, width);  
  310.   }  
  311.   
  312.   CHECK(cv_cropped_img.data);  
  313.   
  314.   Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  315.   int top_index;  
  316.   for (int h = 0; h < height; ++h) {  
  317.     const uchar* ptr = cv_cropped_img.ptr<uchar>(h);  
  318.     int img_index = 0;  
  319.     for (int w = 0; w < width; ++w) {  
  320.       for (int c = 0; c < img_channels; ++c) {  
  321.         if (do_mirror) {  
  322.           top_index = (c * height + h) * width + (width - 1 - w);  
  323.         } else {  
  324.           top_index = (c * height + h) * width + w;  
  325.         }  
  326.         // int top_index = (c * height + h) * width + w;  
  327.         Dtype pixel = static_cast<Dtype>(ptr[img_index++]);  
  328.         if (has_mean_file) {  
  329.           int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;  
  330.           transformed_data[top_index] =  
  331.             (pixel - mean[mean_index]) * scale;  
  332.         } else {  
  333.           if (has_mean_values) {  
  334.             transformed_data[top_index] =  
  335.               (pixel - mean_values_[c]) * scale;  
  336.           } else {  
  337.             transformed_data[top_index] = pixel * scale;  
  338.           }  
  339.         }  
  340.       }  
  341.     }  
  342.   }  
  343. }  
  344. #endif  // USE_OPENCV  
  345.   
  346. template<typename Dtype>  
  347. void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,  
  348.                                        Blob<Dtype>* transformed_blob) {  
  349.   const int crop_size = param_.crop_size();  
  350.   const int input_num = input_blob->num();  
  351.   const int input_channels = input_blob->channels();  
  352.   const int input_height = input_blob->height();  
  353.   const int input_width = input_blob->width();  
  354.   
  355.   if (transformed_blob->count() == 0) {  
  356.     // Initialize transformed_blob with the right shape.  
  357.     if (crop_size) {  
  358.       transformed_blob->Reshape(input_num, input_channels,  
  359.                                 crop_size, crop_size);  
  360.     } else {  
  361.       transformed_blob->Reshape(input_num, input_channels,  
  362.                                 input_height, input_width);  
  363.     }  
  364.   }  
  365.   
  366.   const int num = transformed_blob->num();  
  367.   const int channels = transformed_blob->channels();  
  368.   const int height = transformed_blob->height();  
  369.   const int width = transformed_blob->width();  
  370.   const int size = transformed_blob->count();  
  371.   
  372.   CHECK_LE(input_num, num);  
  373.   CHECK_EQ(input_channels, channels);  
  374.   CHECK_GE(input_height, height);  
  375.   CHECK_GE(input_width, width);  
  376.   
  377.   
  378.   const Dtype scale = param_.scale();  
  379.   const bool do_mirror = param_.mirror() && Rand(2);  
  380.   const bool has_mean_file = param_.has_mean_file();  
  381.   const bool has_mean_values = mean_values_.size() > 0;  
  382.   
  383.   int h_off = 0;  
  384.   int w_off = 0;  
  385.   if (crop_size) {  
  386.     CHECK_EQ(crop_size, height);  
  387.     CHECK_EQ(crop_size, width);  
  388.     // We only do random crop when we do training.  
  389.     if (phase_ == TRAIN) {  
  390.       h_off = Rand(input_height - crop_size + 1);  
  391.       w_off = Rand(input_width - crop_size + 1);  
  392.     } else {  
  393.       h_off = (input_height - crop_size) / 2;  
  394.       w_off = (input_width - crop_size) / 2;  
  395.     }  
  396.   } else {  
  397.     CHECK_EQ(input_height, height);  
  398.     CHECK_EQ(input_width, width);  
  399.   }  
  400.   
  401.   // 如果有均值文件则  
  402.   Dtype* input_data = input_blob->mutable_cpu_data();  
  403.   if (has_mean_file) {  
  404.     CHECK_EQ(input_channels, data_mean_.channels());  
  405.     CHECK_EQ(input_height, data_mean_.height());  
  406.     CHECK_EQ(input_width, data_mean_.width());  
  407.     for (int n = 0; n < input_num; ++n) {  
  408.       int offset = input_blob->offset(n);  
  409.       /* 
  410.          template <typename Dtype> 
  411.        void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y); 
  412.        math_function中定义的caffe_sub目的是矩阵相减input_data(以offset开始的矩阵) = input_data(以offset开始的矩阵) - data_mean_ 
  413.     */  
  414.       caffe_sub(data_mean_.count(), input_data + offset,  
  415.             data_mean_.cpu_data(), input_data + offset);  
  416.     }  
  417.   }  
  418.   // 如果每个channel有均值则  
  419.   if (has_mean_values) {  
  420.     CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<  
  421.      "Specify either 1 mean_value or as many as channels: " << input_channels;  
  422.     if (mean_values_.size() == 1) {  
  423.       caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);  
  424.     } else {  
  425.       for (int n = 0; n < input_num; ++n) {  
  426.         for (int c = 0; c < input_channels; ++c) {  
  427.           int offset = input_blob->offset(n, c);  
  428.           // 给nput_data[offset]地址开始的每一个元素加上一个-mean_values_[c]  
  429.           caffe_add_scalar(input_height * input_width, -(mean_values_[c]),  
  430.             input_data + offset);  
  431.         }  
  432.       }  
  433.     }  
  434.   }  
  435.   
  436.   // 如果啥均值都没有则直接复制  
  437.   Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  438.   
  439.   for (int n = 0; n < input_num; ++n) {  
  440.     int top_index_n = n * channels;  
  441.     int data_index_n = n * channels;  
  442.     for (int c = 0; c < channels; ++c) {  
  443.       int top_index_c = (top_index_n + c) * height;  
  444.       int data_index_c = (data_index_n + c) * input_height + h_off;  
  445.       for (int h = 0; h < height; ++h) {  
  446.         int top_index_h = (top_index_c + h) * width;  
  447.         int data_index_h = (data_index_c + h) * input_width + w_off;  
  448.         if (do_mirror) {  
  449.           int top_index_w = top_index_h + width - 1;  
  450.           for (int w = 0; w < width; ++w) {  
  451.             transformed_data[top_index_w-w] = input_data[data_index_h + w];  
  452.           }  
  453.         } else {  
  454.           for (int w = 0; w < width; ++w) {  
  455.             transformed_data[top_index_h + w] = input_data[data_index_h + w];  
  456.           }  
  457.         }  
  458.       }  
  459.     }  
  460.   }  
  461.   if (scale != Dtype(1)) {  
  462.     DLOG(INFO) << "Scale: " << scale;  
  463.     caffe_scal(size, scale, transformed_data);  
  464.   }  
  465. }  
  466.   
  467. template<typename Dtype>  
  468. vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {  
  469.   if (datum.encoded()) {  
  470. #ifdef USE_OPENCV // 如果使用OpenCV则可以用先转换为CVMat,然后在推断blob的形状  
  471.     CHECK(!(param_.force_color() && param_.force_gray()))  
  472.         << "cannot set both force_color and force_gray";  
  473.     cv::Mat cv_img;  
  474.     if (param_.force_color() || param_.force_gray()) {  
  475.     // If force_color then decode in color otherwise decode in gray.  
  476.       cv_img = DecodeDatumToCVMat(datum, param_.force_color());  
  477.     } else {  
  478.       cv_img = DecodeDatumToCVMatNative(datum);  
  479.     }  
  480.     // InferBlobShape using the cv::image.  
  481.     return InferBlobShape(cv_img);  
  482. #else  
  483.     LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";  
  484. #endif  // USE_OPENCV  
  485.   }  
  486.   
  487.   // 否则直接粗暴地从datum里面获取形状的数据  
  488.   const int crop_size = param_.crop_size();  
  489.   const int datum_channels = datum.channels();  
  490.   const int datum_height = datum.height();  
  491.   const int datum_width = datum.width();  
  492.   // Check dimensions.  
  493.   CHECK_GT(datum_channels, 0);  
  494.   CHECK_GE(datum_height, crop_size);  
  495.   CHECK_GE(datum_width, crop_size);  
  496.   // Build BlobShape.  
  497.   vector<int> shape(4);  
  498.   shape[0] = 1;  
  499.   shape[1] = datum_channels;  
  500.   shape[2] = (crop_size)? crop_size: datum_height;  
  501.   shape[3] = (crop_size)? crop_size: datum_width;  
  502.   return shape;  
  503. }  
  504.   
  505. template<typename Dtype>  
  506. vector<int> DataTransformer<Dtype>::InferBlobShape(  
  507.     const vector<Datum> & datum_vector) {  
  508.   const int num = datum_vector.size();  
  509.   CHECK_GT(num, 0) << "There is no datum to in the vector";  
  510.   // Use first datum in the vector to InferBlobShape.  
  511.   // 使用第一个来进行推断  
  512.   vector<int> shape = InferBlobShape(datum_vector[0]);  
  513.   // Adjust num to the size of the vector.  
  514.   shape[0] = num;  
  515.   return shape;  
  516. }  
  517.   
  518. #ifdef USE_OPENCV  
  519. // 如果使用OpenCV  
  520. // 使用CVMat中的信息来推断形状  
  521. template<typename Dtype>  
  522. vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {  
  523.   const int crop_size = param_.crop_size();  
  524.   const int img_channels = cv_img.channels();  
  525.   const int img_height = cv_img.rows;  
  526.   const int img_width = cv_img.cols;  
  527.   // Check dimensions.  
  528.   CHECK_GT(img_channels, 0);  
  529.   CHECK_GE(img_height, crop_size);  
  530.   CHECK_GE(img_width, crop_size);  
  531.   // Build BlobShape.  
  532.   vector<int> shape(4);  
  533.   shape[0] = 1;  
  534.   shape[1] = img_channels;  
  535.   shape[2] = (crop_size)? crop_size: img_height;  
  536.   shape[3] = (crop_size)? crop_size: img_width;  
  537.   return shape;  
  538. }  
  539.   
  540. template<typename Dtype>  
  541. vector<int> DataTransformer<Dtype>::InferBlobShape(  
  542.     const vector<cv::Mat> & mat_vector) {  
  543.   const int num = mat_vector.size();  
  544.   CHECK_GT(num, 0) << "There is no cv_img to in the vector";  
  545.   // Use first cv_img in the vector to InferBlobShape.  
  546.   // 使用第一个来推断  
  547.   vector<int> shape = InferBlobShape(mat_vector[0]);  
  548.   // Adjust num to the size of the vector.  
  549.   shape[0] = num;  
  550.   return shape;  
  551. }  
  552. #endif  // USE_OPENCV  
  553.   
  554. // 初始化随机数种子  
  555. template <typename Dtype>  
  556. void DataTransformer<Dtype>::InitRand() {  
  557.   // 要么需要镜像要么训练阶段和需要crop同时满足的情况下才初始化随机数种子  
  558.   const bool needs_rand = param_.mirror() ||  
  559.       (phase_ == TRAIN && param_.crop_size());  
  560.   if (needs_rand) {  
  561.     const unsigned int rng_seed = caffe_rng_rand();// 获得随机数种子(通过熵池或者时间生成种子)  
  562.     rng_.reset(new Caffe::RNG(rng_seed));//初始化随机数种子并实例化随机数生成器  
  563.   } else {  
  564.     rng_.reset();//否则随机数生成器设置为空  
  565.   }  
  566. }  
  567.   
  568. // 产生从0到n的随机数  
  569. template <typename Dtype>  
  570. int DataTransformer<Dtype>::Rand(int n) {  
  571.   CHECK(rng_);  
  572.   CHECK_GT(n, 0);  
  573.   caffe::rng_t* rng =  
  574.       static_cast<caffe::rng_t*>(rng_->generator());  
  575.   return ((*rng)() % n);  
  576. }  
  577.   
  578. INSTANTIATE_CLASS(DataTransformer);  
  579. /* 
  580. 初始化类的宏定义是这样的,前面有讲过,这里再给出来 
  581. #define INSTANTIATE_CLASS(classname) \ 
  582.   char gInstantiationGuard##classname; \ 
  583.   template class classname<float>; \ 
  584.   template class classname<double> 
  585. */  
  586.   
  587. }  // namespace caffe  

三、与DataTransformer类相关类的介绍

(1)io的介绍

首先给出io中定义的各个函数的含义:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #ifndef CAFFE_UTIL_IO_H_  
  2. #define CAFFE_UTIL_IO_H_  
  3.   
  4. #include <unistd.h>  
  5. #include <string>  
  6.   
  7. #include "google/protobuf/message.h"  
  8.   
  9. #include "caffe/blob.hpp"  
  10. #include "caffe/common.hpp"  
  11. #include "caffe/proto/caffe.pb.h"  
  12.   
  13. namespace caffe {  
  14.   
  15. using ::google::protobuf::Message;  
  16. // 内联函数,创建临时文件  
  17. inline void MakeTempFilename(string* temp_filename) {  
  18.   temp_filename->clear();  
  19.   *temp_filename = "/tmp/caffe_test.XXXXXX";  
  20.   char* temp_filename_cstr = new char[temp_filename->size() + 1];  
  21.   // NOLINT_NEXT_LINE(runtime/printf)  
  22.   strcpy(temp_filename_cstr, temp_filename->c_str());  
  23.   int fd = mkstemp(temp_filename_cstr);  
  24.   CHECK_GE(fd, 0) << "Failed to open a temporary file at: " << *temp_filename;  
  25.   close(fd);  
  26.   *temp_filename = temp_filename_cstr;  
  27.   delete[] temp_filename_cstr;  
  28. }  
  29.   
  30. // 内联函数,创建临时目录  
  31. inline void MakeTempDir(string* temp_dirname) {  
  32.   temp_dirname->clear();  
  33.   *temp_dirname = "/tmp/caffe_test.XXXXXX";  
  34.   char* temp_dirname_cstr = new char[temp_dirname->size() + 1];  
  35.   // NOLINT_NEXT_LINE(runtime/printf)  
  36.   strcpy(temp_dirname_cstr, temp_dirname->c_str());  
  37.   char* mkdtemp_result = mkdtemp(temp_dirname_cstr);  
  38.   CHECK(mkdtemp_result != NULL)  
  39.       << "Failed to create a temporary directory at: " << *temp_dirname;  
  40.   *temp_dirname = temp_dirname_cstr;  
  41.   delete[] temp_dirname_cstr;  
  42. }  
  43. // 从txt读取proto的定义  
  44. bool ReadProtoFromTextFile(const char* filename, Message* proto);  
  45.   
  46. // 从text读取proto的定义  
  47. inline bool ReadProtoFromTextFile(const string& filename, Message* proto) {  
  48.   return ReadProtoFromTextFile(filename.c_str(), proto);  
  49. }  
  50. // 从text读取proto的定义,只是增加了检查而已  
  51. inline void ReadProtoFromTextFileOrDie(const char* filename, Message* proto) {  
  52.   CHECK(ReadProtoFromTextFile(filename, proto));  
  53. }  
  54. // 从text读取proto的定义,只是增加了检查而已  
  55. inline void ReadProtoFromTextFileOrDie(const string& filename, Message* proto) {  
  56.   ReadProtoFromTextFileOrDie(filename.c_str(), proto);  
  57. }  
  58. // 将proto写入到txt文件  
  59. void WriteProtoToTextFile(const Message& proto, const char* filename);  
  60. inline void WriteProtoToTextFile(const Message& proto, const string& filename) {  
  61.   WriteProtoToTextFile(proto, filename.c_str());  
  62. }  
  63. // 从bin读取proto的定义  
  64. bool ReadProtoFromBinaryFile(const char* filename, Message* proto);  
  65. // 从bin读取proto的定义  
  66. inline bool ReadProtoFromBinaryFile(const string& filename, Message* proto) {  
  67.   return ReadProtoFromBinaryFile(filename.c_str(), proto);  
  68. }  
  69. // 从bin读取proto的定义,只是增加了检查而已  
  70. inline void ReadProtoFromBinaryFileOrDie(const char* filename, Message* proto) {  
  71.   CHECK(ReadProtoFromBinaryFile(filename, proto));  
  72. }  
  73. // 从bin读取proto的定义,只是增加了检查而已  
  74. inline void ReadProtoFromBinaryFileOrDie(const string& filename,  
  75.                                          Message* proto) {  
  76.   ReadProtoFromBinaryFileOrDie(filename.c_str(), proto);  
  77. }  
  78.   
  79. // 将proto写入到bin文件  
  80. void WriteProtoToBinaryFile(const Message& proto, const char* filename);  
  81. // 内联函数,将proto写入到bin文件  
  82. inline void WriteProtoToBinaryFile(  
  83.     const Message& proto, const string& filename) {  
  84.   WriteProtoToBinaryFile(proto, filename.c_str());  
  85. }  
  86. // 从文件读取数据到Datum  
  87. bool ReadFileToDatum(const string& filename, const int label, Datum* datum);  
  88. // 内联函数,从文件读取数据到Datum  
  89. inline bool ReadFileToDatum(const string& filename, Datum* datum) {  
  90.   return ReadFileToDatum(filename, -1, datum);  
  91. }  
  92.   
  93. // 从图像文件读取数据到Datum  
  94. bool ReadImageToDatum(const string& filename, const int label,  
  95.     const int height, const int width, const bool is_color,  
  96.     const std::string & encoding, Datum* datum);  
  97. // 内联函数,从图像文件(彩色还是黑白?)读取数据到Datum,指定图像大小  
  98. inline bool ReadImageToDatum(const string& filename, const int label,  
  99.     const int height, const int width, const bool is_color, Datum* datum) {  
  100.   return ReadImageToDatum(filename, label, height, width, is_color,  
  101.                           "", datum);  
  102. }  
  103. // 内联函数,从彩色图像文件读取数据到Datum,指定图像大小  
  104. inline bool ReadImageToDatum(const string& filename, const int label,  
  105.     const int height, const int width, Datum* datum) {  
  106.   return ReadImageToDatum(filename, label, height, width, true, datum);  
  107. }  
  108. // 内联函数,从图像文件(彩色还是黑白?)读取数据到Datum,自动获取图像大小  
  109. inline bool ReadImageToDatum(const string& filename, const int label,  
  110.     const bool is_color, Datum* datum) {  
  111.   return ReadImageToDatum(filename, label, 0, 0, is_color, datum);  
  112. }  
  113. // 内联函数,从彩色图像文件读取数据到Datum,自动获取图像大小  
  114. inline bool ReadImageToDatum(const string& filename, const int label,  
  115.     Datum* datum) {  
  116.   return ReadImageToDatum(filename, label, 0, 0, true, datum);  
  117. }  
  118. // 内联函数,从彩色图像文件读取数据到Datum,自动获取图像大小,指定编码格式  
  119. inline bool ReadImageToDatum(const string& filename, const int label,  
  120.     const std::string & encoding, Datum* datum) {  
  121.   return ReadImageToDatum(filename, label, 0, 0, true, encoding, datum);  
  122. }  
  123. // 对Datum进行解码  
  124. bool DecodeDatumNative(Datum* datum);  
  125. // 对彩色图像的Datum进行解码  
  126. bool DecodeDatum(Datum* datum, bool is_color);  
  127.   
  128. #ifdef USE_OPENCV  
  129. // 将图像读取到CVMat,指定图像大小,是否彩色  
  130. cv::Mat ReadImageToCVMat(const string& filename,  
  131.     const int height, const int width, const bool is_color);  
  132. // 将图像读取到CVMat,指定图像大小  
  133. cv::Mat ReadImageToCVMat(const string& filename,  
  134.     const int height, const int width);  
  135. // 将图像读取到CVMat,指定是否彩色  
  136. cv::Mat ReadImageToCVMat(const string& filename,  
  137.     const bool is_color);  
  138. // 将图像读取到CVMat  
  139. cv::Mat ReadImageToCVMat(const string& filename);  
  140. // 将Datum解码为为CVMat  
  141. cv::Mat DecodeDatumToCVMatNative(const Datum& datum);  
  142. // 将彩色图像的Datum解码为为CVMat  
  143. cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);  
  144. // 将CVMat转换为Datum  
  145. void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);  
  146. #endif  // USE_OPENCV  
  147.   
  148. }  // namespace caffe  
  149.   
  150. #endif   // CAFFE_UTIL_IO_H_  
接下来给出io中的具体的实现的注释
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #include <fcntl.h>  
  2. #include <google/protobuf/io/coded_stream.h>  
  3. #include <google/protobuf/io/zero_copy_stream_impl.h>  
  4. #include <google/protobuf/text_format.h>  
  5. #include <opencv2/core/core.hpp>  
  6. #ifdef USE_OPENCV  
  7. #include <opencv2/highgui/highgui.hpp>  
  8. #include <opencv2/highgui/highgui_c.h>  
  9. #include <opencv2/imgproc/imgproc.hpp>  
  10. #endif  // USE_OPENCV  
  11. #include <stdint.h>  
  12.   
  13. #include <algorithm>  
  14. #include <fstream>  // NOLINT(readability/streams)  
  15. #include <string>  
  16. #include <vector>  
  17.   
  18. #include "caffe/common.hpp"  
  19. #include "caffe/proto/caffe.pb.h"  
  20. #include "caffe/util/io.hpp"  
  21.   
  22. const int kProtoReadBytesLimit = INT_MAX;  // Max size of 2 GB minus 1 byte.  
  23.   
  24. namespace caffe {  
  25.   
  26. using google::protobuf::io::FileInputStream;  
  27. using google::protobuf::io::FileOutputStream;  
  28. using google::protobuf::io::ZeroCopyInputStream;  
  29. using google::protobuf::io::CodedInputStream;  
  30. using google::protobuf::io::ZeroCopyOutputStream;  
  31. using google::protobuf::io::CodedOutputStream;  
  32. using google::protobuf::Message;  
  33.   
  34. // 从文件读取Proto的txt文件  
  35. bool ReadProtoFromTextFile(const char* filename, Message* proto) {  
  36.   int fd = open(filename, O_RDONLY);  
  37.   CHECK_NE(fd, -1) << "File not found: " << filename;  
  38.   FileInputStream* input = new FileInputStream(fd);  
  39.   // 注意如何使用protobuf去读取  
  40.   bool success = google::protobuf::TextFormat::Parse(input, proto);  
  41.   delete input;  
  42.   close(fd);  
  43.   return success;  
  44. }  
  45.   
  46. // 将proto写入到txt文件  
  47. void WriteProtoToTextFile(const Message& proto, const char* filename) {  
  48.   int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);  
  49.   FileOutputStream* output = new FileOutputStream(fd);  
  50.   // 注意如何写入  
  51.   CHECK(google::protobuf::TextFormat::Print(proto, output));  
  52.   delete output;  
  53.   close(fd);  
  54. }  
  55.   
  56. // 从bin读取proto的定义  
  57. bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {  
  58.   int fd = open(filename, O_RDONLY);  
  59.   CHECK_NE(fd, -1) << "File not found: " << filename;  
  60.   ZeroCopyInputStream* raw_input = new FileInputStream(fd);  
  61.   //  解码流com.google.protobuf.CodedInputStream  
  62.   CodedInputStream* coded_input = new CodedInputStream(raw_input);  
  63.   coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);  
  64.   
  65.   bool success = proto->ParseFromCodedStream(coded_input);  
  66.   
  67.   delete coded_input;  
  68.   delete raw_input;  
  69.   close(fd);  
  70.   return success;  
  71. }  
  72.   
  73. // 将proto写入到bin文件  
  74. void WriteProtoToBinaryFile(const Message& proto, const char* filename) {  
  75.   fstream output(filename, ios::out | ios::trunc | ios::binary);  
  76.   CHECK(proto.SerializeToOstream(&output));  
  77. }  
  78.   
  79. #ifdef USE_OPENCV  
  80. // 将图像读取到CVMat,指定图像大小,是否彩色  
  81. cv::Mat ReadImageToCVMat(const string& filename,  
  82.     const int height, const int width, const bool is_color) {  
  83.   cv::Mat cv_img;  
  84.   int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :  
  85.     CV_LOAD_IMAGE_GRAYSCALE);  
  86.   cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag);  
  87.   if (!cv_img_origin.data) {  
  88.     LOG(ERROR) << "Could not open or find file " << filename;  
  89.     return cv_img_origin;  
  90.   }  
  91.   if (height > 0 && width > 0) {  
  92.     cv::resize(cv_img_origin, cv_img, cv::Size(width, height));  
  93.   } else {  
  94.     cv_img = cv_img_origin;  
  95.   }  
  96.   return cv_img;  
  97. }  
  98.   
  99. cv::Mat ReadImageToCVMat(const string& filename,  
  100.     const int height, const int width) {  
  101.   return ReadImageToCVMat(filename, height, width, true);  
  102. }  
  103.   
  104. cv::Mat ReadImageToCVMat(const string& filename,  
  105.     const bool is_color) {  
  106.   return ReadImageToCVMat(filename, 0, 0, is_color);  
  107. }  
  108.   
  109. cv::Mat ReadImageToCVMat(const string& filename) {  
  110.   return ReadImageToCVMat(filename, 0, 0, true);  
  111. }  
  112.   
  113. // Do the file extension and encoding match?  
  114. // 看看是不是jpg还是jpeg的图像  
  115. static bool matchExt(const std::string & fn,  
  116.                      std::string en) {  
  117.   size_t p = fn.rfind('.');  
  118.   std::string ext = p != fn.npos ? fn.substr(p) : fn;  
  119.   std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);  
  120.   std::transform(en.begin(), en.end(), en.begin(), ::tolower);  
  121.   if ( ext == en )  
  122.     return true;  
  123.   if ( en == "jpg" && ext == "jpeg" )  
  124.     return true;  
  125.   return false;  
  126. }  
  127. // 从图像文件读取数据到Datum  
  128. bool ReadImageToDatum(const string& filename, const int label,  
  129.     const int height, const int width, const bool is_color,  
  130.     const std::string & encoding, Datum* datum) {  
  131.   cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);  
  132.   if (cv_img.data) {  
  133.     if (encoding.size()) {  
  134.       if ( (cv_img.channels() == 3) == is_color && !height && !width &&  
  135.           matchExt(filename, encoding) )  
  136.         return ReadFileToDatum(filename, label, datum);  
  137.       std::vector<uchar> buf;  
  138.       // 对数据解码  
  139.       cv::imencode("."+encoding, cv_img, buf);  
  140.       datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),  
  141.                       buf.size()));  
  142.       // 数据标签  
  143.       datum->set_label(label);  
  144.       // 是否被编码  
  145.       datum->set_encoded(true);  
  146.       return true;  
  147.     }  
  148.     CVMatToDatum(cv_img, datum);  
  149.     datum->set_label(label);  
  150.     return true;  
  151.   } else {  
  152.     return false;  
  153.   }  
  154. }  
  155. #endif  // USE_OPENCV  
  156. // 从文件读取数据到Datum  
  157. bool ReadFileToDatum(const string& filename, const int label,  
  158.     Datum* datum) {  
  159.   std::streampos size;  
  160.   
  161.   fstream file(filename.c_str(), ios::in|ios::binary|ios::ate);  
  162.   if (file.is_open()) {  
  163.     size = file.tellg();  
  164.     std::string buffer(size, ' ');  
  165.     file.seekg(0, ios::beg);  
  166.     file.read(&buffer[0], size);  
  167.     file.close();  
  168.     datum->set_data(buffer);  
  169.     datum->set_label(label);  
  170.     datum->set_encoded(true);  
  171.     return true;  
  172.   } else {  
  173.     return false;  
  174.   }  
  175. }  
  176.   
  177. #ifdef USE_OPENCV  
  178. // 直接编码数据的Datum到CVMat  
  179. cv::Mat DecodeDatumToCVMatNative(const Datum& datum) {  
  180.   cv::Mat cv_img;  
  181.   CHECK(datum.encoded()) << "Datum not encoded";  
  182.   const string& data = datum.data();  
  183.   std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());  
  184.   cv_img = cv::imdecode(vec_data, -1);//flag=-1  
  185.   if (!cv_img.data) {  
  186.     LOG(ERROR) << "Could not decode datum ";  
  187.   }  
  188.   return cv_img;  
  189. }  
  190.   
  191. // 直接编码彩色或者非彩色Datum到CVMat  
  192. cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) {  
  193.   cv::Mat cv_img;  
  194.   CHECK(datum.encoded()) << "Datum not encoded";  
  195.   const string& data = datum.data();  
  196.   std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());  
  197.   int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :  
  198.     CV_LOAD_IMAGE_GRAYSCALE);  
  199.   cv_img = cv::imdecode(vec_data, cv_read_flag);// flag为用户指定的  
  200.   if (!cv_img.data) {  
  201.     LOG(ERROR) << "Could not decode datum ";  
  202.   }  
  203.   return cv_img;  
  204. }  
  205.   
  206. // If Datum is encoded will decoded using DecodeDatumToCVMat and CVMatToDatum  
  207. // If Datum is not encoded will do nothing  
  208. bool DecodeDatumNative(Datum* datum) {  
  209.   if (datum->encoded()) {  
  210.     cv::Mat cv_img = DecodeDatumToCVMatNative((*datum));  
  211.     CVMatToDatum(cv_img, datum);  
  212.     return true;  
  213.   } else {  
  214.     return false;  
  215.   }  
  216. }  
  217.   
  218. // 将Datum进行解码  
  219. bool DecodeDatum(Datum* datum, bool is_color) {  
  220.   if (datum->encoded()) {  
  221.     cv::Mat cv_img = DecodeDatumToCVMat((*datum), is_color);  
  222.     CVMatToDatum(cv_img, datum);  
  223.     return true;  
  224.   } else {  
  225.     return false;  
  226.   }  
  227. }  
  228.   
  229. // 将CVMat转换到Datum  
  230. void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {  
  231.   CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";  
  232.   datum->set_channels(cv_img.channels());  
  233.   datum->set_height(cv_img.rows);  
  234.   datum->set_width(cv_img.cols);  
  235.   datum->clear_data();  
  236.   datum->clear_float_data();  
  237.   datum->set_encoded(false);  
  238.   int datum_channels = datum->channels();  
  239.   int datum_height = datum->height();  
  240.   int datum_width = datum->width();  
  241.   int datum_size = datum_channels * datum_height * datum_width;  
  242.   std::string buffer(datum_size, ' ');  
  243.   for (int h = 0; h < datum_height; ++h) {  
  244.     const uchar* ptr = cv_img.ptr<uchar>(h);  
  245.     int img_index = 0;  
  246.     for (int w = 0; w < datum_width; ++w) {  
  247.       for (int c = 0; c < datum_channels; ++c) {  
  248.         int datum_index = (c * datum_height + h) * datum_width + w;  
  249.         buffer[datum_index] = static_cast<char>(ptr[img_index++]);  
  250.       }  
  251.     }  
  252.   }  
  253.   datum->set_data(buffer);  
  254. }  
  255. #endif  // USE_OPENCV  
  256. }  // namespace caffe  

四、总结

总结起来就是,DataTransformer所作的工作实际上就是crop数据,让数据减去均值,以及缩放数据。
然后就是根据数据来推断形状。此外还介绍了io的内容,里面包含了创建临时文件临时目录操作,以及从txt文件以及bin文件读取proto数据或者写入proto的数据到txt或者bin文件。


参考:

[1]你可能需要了解关于cv::imencode和 cv::imdecode函数的flag的含义

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值