caffe源码阅读之layer(2)——DataLayer层(2)

四、数据转换器
数据转换器主要用于对原始数据进行预处理,包括:随机切块、随机镜像、幅度缩放、去均值和灰度/色度变换等。数据转换器对应的类为
DataTransformer

成员变量

  // Transformation parameters; the type is generated by Protocol Buffers.
  TransformationParameter param_;//变换参数,该参数由ProtoBuffer生成
  // Random number generator.
  shared_ptr<Caffe::RNG> rng_;//随机数生成器
  // Phase the transformer is currently running in.
  Phase phase_;//当前的运行阶段
  // Mean image, read from the mean file.
  Blob<Dtype> data_mean_;//均值图像,从均值文件中读取
  // Mean values, taken from param_.
  vector<Dtype> mean_values_;//均值数值,从param_中获取
1、构造函数

/**
 * Builds a transformer from its configuration and the phase it runs in.
 *
 * The mean to subtract may be configured either as a mean file or as a list
 * of mean values; the CHECKs below reject a configuration that gives both.
 *
 * @param param transformation settings, copied into param_.
 * @param phase phase value stored in phase_.
 */
template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
    Phase phase) : param_(param), phase_(phase) {
  // Mean supplied as a file: read it and deserialize it into data_mean_.
  if (param_.has_mean_file()) {
    CHECK_EQ(param_.mean_value_size(), 0) <<"Cannot specify mean_file and mean_value at the same time";
    const string& mean_path = param.mean_file();
    // Only the root solver reports which file is being loaded.
    if (Caffe::root_solver()) {
      LOG(INFO) << "Loading mean file from: " << mean_path;
    }
    BlobProto mean_proto;
    ReadProtoFromBinaryFileOrDie(mean_path.c_str(), &mean_proto);
    data_mean_.FromProto(mean_proto);
  }

  // Mean supplied as explicit values: copy each one into mean_values_.
  const int num_mean_values = param_.mean_value_size();
  if (num_mean_values > 0) {
    CHECK(param_.has_mean_file() == false) << "Cannot specify mean_file and mean_value at the same time";
    for (int idx = 0; idx < num_mean_values; ++idx) {
      mean_values_.push_back(param_.mean_value(idx));
    }
  }
}
  // Virtual destructor with an empty body.
  virtual ~DataTransformer() {}
2、数据转换
数据转换功能主要由Transform()函数实现,包括切块、幅度缩放、随机镜像、去均值等操作;该函数有以下5种重载形式。
Datum数据结构主要用于从LMDB/LEVELDB读取数据或向其写入数据,专门为数据或特征图提供序列化和反序列化功能
// Overloads of Transform(); every overload writes its result to transformed_blob.
// Input: a single Datum. Output: a Blob.
 void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
 // Input: a vector of Datum. Output: a Blob.
 void Transform(const vector<Datum> & datum_vector,Blob<Dtype>* transformed_blob);
 // Input: a vector of cv::Mat images. Output: a Blob.
 void Transform(const vector<cv::Mat> & mat_vector,Blob<Dtype>* transformed_blob);
 // Input: a single cv::Mat image. Output: a Blob.
 void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);
 // Input: a Blob. Output: a Blob.
 void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);
因为上面几种函数的处理过程大体相同,这里只看第一种函数对应的实现(输出为数据指针 Dtype* 的版本)
/**
 * Applies the configured crop, mirror, mean subtraction and scaling to one
 * Datum and writes the result into a caller-provided buffer.
 *
 * @param datum            source image. Pixels are read from data() when it
 *                         is non-empty, otherwise from float_data().
 * @param transformed_data destination buffer, written at index
 *                         (c * height + h) * width + w of the output image.
 *                         NOTE(review): presumably must hold at least
 *                         channels * height * width values -- confirm at callers.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,Dtype* transformed_data) {
  // Geometry and payload of the incoming datum.
  const string& raw_bytes = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Transformation settings taken from param_.
  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  // Mirroring is decided per call: Rand(2) is consulted only when mirroring is
  // enabled (presumably a coin flip -- confirm against Rand's definition).
  const bool flip_horizontally = param_.mirror() && Rand(2);
  const bool subtract_mean_image = param_.has_mean_file();
  const bool byte_encoded = raw_bytes.size() > 0;
  const bool subtract_mean_values = mean_values_.size() > 0;

  // Sanity checks on the input geometry.
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  // Mean image: must have exactly the same shape as the datum.
  Dtype* mean_image = NULL;
  if (subtract_mean_image) {
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean_image = data_mean_.mutable_cpu_data();
  }

  // Per-channel mean values: a single value is replicated for every channel.
  if (subtract_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // Copy first so the fill value does not alias the vector being resized.
      const Dtype shared_mean = mean_values_[0];
      mean_values_.resize(datum_channels, shared_mean);
    }
  }

  // Output size and crop window origin.
  const bool cropping = crop_size != 0;
  const int out_height = cropping ? crop_size : datum_height;
  const int out_width = cropping ? crop_size : datum_width;
  int row_offset = 0;
  int col_offset = 0;
  if (cropping) {
    if (phase_ == TRAIN) {
      // Training: the crop window position is drawn from Rand().
      row_offset = Rand(datum_height - crop_size + 1);
      col_offset = Rand(datum_width - crop_size + 1);
    } else {
      // Otherwise: take the centre crop.
      row_offset = (datum_height - crop_size) / 2;
      col_offset = (datum_width - crop_size) / 2;
    }
  }

  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < out_height; ++h) {
      // Start of this row in the source datum and in the output buffer.
      const int src_row =
          (c * datum_height + row_offset + h) * datum_width + col_offset;
      const int dst_row = (c * out_height + h) * out_width;
      for (int w = 0; w < out_width; ++w) {
        const int src = src_row + w;
        // When mirroring, the row is written right-to-left.
        const int dst =
            dst_row + (flip_horizontally ? (out_width - 1 - w) : w);

        // Bytes are reinterpreted as unsigned 8-bit before widening to Dtype.
        const Dtype pixel = byte_encoded
            ? static_cast<Dtype>(static_cast<uint8_t>(raw_bytes[src]))
            : datum.float_data(src);

        // The mean image is indexed by the source (uncropped) position.
        if (subtract_mean_image) {
          transformed_data[dst] = (pixel - mean_image[src]) * scale;
        } else if (subtract_mean_values) {
          transformed_data[dst] = (pixel - mean_values_[c]) * scale;
        } else {
          transformed_data[dst] = pixel * scale;
        }
      }
    }
  }
}
3、获得输出Blob的尺寸
主要有以下几种函数
// Overloads of InferBlobShape(); each returns a shape as a vector<int>.
// Shape inferred from a single Datum.
vector<int> InferBlobShape(const Datum& datum);
// Shape inferred from a vector of Datum.
vector<int> InferBlobShape(const vector<Datum> & datum_vector);
 // Shape inferred from a vector of cv::Mat images.
 vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);
 // Shape inferred from a single cv::Mat image.
 vector<int> InferBlobShape(const cv::Mat& cv_img);
 
 // Computes the 4-element output shape {1, channels, height, width} for a Datum.
 // An encoded datum is decoded first (only when built with USE_OPENCV) and the
 // shape is then inferred from the decoded image. When crop_size is non-zero,
 // height and width are both crop_size; otherwise they are the datum's own.
 template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
  if (datum.encoded()) {
#ifdef USE_OPENCV
    CHECK(!(param_.force_color() && param_.force_gray()))<< "cannot set both force_color and force_gray";
    // A forced colour or grayscale mode uses the explicit decoder; otherwise
    // the datum is decoded in its native format.
    const bool forced_mode = param_.force_color() || param_.force_gray();
    const cv::Mat decoded = forced_mode
        ? DecodeDatumToCVMat(datum, param_.force_color())
        : DecodeDatumToCVMatNative(datum);
    return InferBlobShape(decoded);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  }

  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Validate the dimensions before building the shape.
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  // Assemble {1, C, H, W}; a non-zero crop_size overrides H and W.
  const bool cropping = crop_size != 0;
  vector<int> shape;
  shape.push_back(1);
  shape.push_back(datum_channels);
  shape.push_back(cropping ? crop_size : datum_height);
  shape.push_back(cropping ? crop_size : datum_width);
  return shape;
}
// Computes the 4-element output shape {1, channels, height, width} for an
// OpenCV image. When crop_size is non-zero, height and width are both
// crop_size; otherwise they are the image's rows and cols.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
  const int crop_size = param_.crop_size();
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // Validate the dimensions before building the shape.
  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  // Assemble {1, C, H, W}; a non-zero crop_size overrides H and W.
  const bool cropping = crop_size != 0;
  vector<int> shape;
  shape.push_back(1);
  shape.push_back(img_channels);
  shape.push_back(cropping ? crop_size : img_height);
  shape.push_back(cropping ? crop_size : img_width);
  return shape;
}
4、初始化随机种子
// (Re)initializes the random number generator held in rng_.
// A generator is created only when randomness can be requested: mirroring is
// enabled, or a crop is configured while in the TRAIN phase. Otherwise rng_
// is cleared.
template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  bool needs_rng = param_.mirror();
  if (!needs_rng) {
    needs_rng = (phase_ == TRAIN) && param_.crop_size();
  }

  // No randomness required: drop any existing generator.
  if (!needs_rng) {
    rng_.reset();
    return;
  }

  // Seed a fresh generator with a value obtained from caffe_rng_rand().
  const unsigned int seed = caffe_rng_rand();
  rng_.reset(new Caffe::RNG(seed));
}



  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值