四、数据转换器
数据转换器主要用于对原始数据进行预处理,包括:随机切块、随机镜像、幅度缩放、去均值和灰度/色度变换等。数据转换器的类为
DataTransformer。
数据转换功能主要由Transform()函数实现:切块、幅度缩放、随机镜像、去均值等功能由以下5个Transform()重载函数完成。
Datum是用于从LMDB/LevelDB读取或写入数据的数据结构,专门为数据或特征图提供序列化与反序列化功能。
主要有以下几种函数:
数据转换器主要用于对原始数据进行预处理,包括:随机切块、随机镜像、幅度缩放、去均值和灰度/色度变换等。数据转换器的类为
DataTransformer。
成员变量
TransformationParameter param_;//Transformation settings (a ProtoBuffer-generated message)
shared_ptr<Caffe::RNG> rng_;//Random number generator; (re)created or cleared by InitRand()
Phase phase_;//Current run phase (e.g. TRAIN), set in the constructor
Blob<Dtype> data_mean_;//Mean image, loaded from the mean file in the constructor
vector<Dtype> mean_values_;//Mean value(s) copied from param_.mean_value in the constructor
1、构造函数
/**
 * Builds a transformer for the given run phase.
 *
 * Mean subtraction is configured in one of two mutually exclusive ways:
 *   - mean_file:  a binary BlobProto that is deserialized into data_mean_;
 *   - mean_value: explicit values that are copied into mean_values_.
 * Specifying both aborts via a CHECK failure.
 *
 * @param param  transformation settings; copied into param_.
 * @param phase  run phase; copied into phase_.
 */
template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
    Phase phase) : param_(param), phase_(phase) {
  const bool uses_mean_file = param_.has_mean_file();

  if (uses_mean_file) {
    // A mean file excludes explicit mean values.
    CHECK_EQ(param_.mean_value_size(), 0) <<"Cannot specify mean_file and mean_value at the same time";
    const string& mean_path = param.mean_file();
    if (Caffe::root_solver()) {
      LOG(INFO) << "Loading mean file from: " << mean_path;
    }
    // Deserialize the stored mean blob into data_mean_.
    BlobProto mean_proto;
    ReadProtoFromBinaryFileOrDie(mean_path.c_str(), &mean_proto);
    data_mean_.FromProto(mean_proto);
  }

  const int mean_value_count = param_.mean_value_size();
  if (mean_value_count > 0) {
    // Explicit mean values exclude a mean file.
    CHECK(param_.has_mean_file() == false) << "Cannot specify mean_file and mean_value at the same time";
    for (int idx = 0; idx < mean_value_count; ++idx) {
      mean_values_.push_back(param_.mean_value(idx));
    }
  }
}
// Empty destructor. It is declared virtual, presumably so a derived transformer can be deleted safely through a DataTransformer pointer.
virtual ~DataTransformer() {}
2、数据转换
数据转换功能主要由Transform()函数实现:切块、幅度缩放、随机镜像、去均值等功能由以下5个Transform()重载函数完成。
Datum是用于从LMDB/LevelDB读取或写入数据的数据结构,专门为数据或特征图提供序列化与反序列化功能。
//Transform overloads; each presumably writes its result into *transformed_blob (definitions not shown here).
//Input: a single Datum. (A separate overload taking a raw Dtype* output buffer also exists.)
void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
//Input: a batch of Datums.
void Transform(const vector<Datum> & datum_vector,Blob<Dtype>* transformed_blob);
//Input: a batch of OpenCV images (cv::Mat).
void Transform(const vector<cv::Mat> & mat_vector,Blob<Dtype>* transformed_blob);
//Input: a single OpenCV image.
void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);
//Input: another Blob.
void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);
由于上面几种函数的处理流程大体相同,这里只分析以Datum为输入的版本(注意:下面的定义把结果写入Dtype*数据指针transformed_data,而不是Blob)。
/**
 * Transforms one Datum into a caller-provided output buffer.
 *
 * Steps, in order:
 *   1. Crop to crop_size x crop_size when crop_size != 0: a random offset in
 *      the TRAIN phase, the centered offset otherwise.
 *   2. Mirror horizontally when param_.mirror() is set and Rand(2) is nonzero.
 *   3. Subtract the mean. With a mean file, the mean is applied per element,
 *      indexed in the uncropped source layout. With mean_values_, it is
 *      applied per channel; a single value is applied to every channel.
 *   4. Multiply by param_.scale().
 *
 * Pixels are read from datum.data() (as uint8) when it is non-empty,
 * otherwise from datum.float_data().
 *
 * @param datum            source image, laid out channel-major (C x H x W).
 * @param transformed_data output buffer; must hold at least
 *                         channels * out_h * out_w values, where out_h/out_w
 *                         equal crop_size when cropping, else the datum size.
 *
 * This function does not modify mean_values_.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Dtype* transformed_data) {
  // Source image geometry and payload.
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Transformation settings.
  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_uint8 = data.size() > 0;
  const bool has_mean_values = mean_values_.size() > 0;

  // The crop window must fit inside the source image.
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  // A per-element mean image must match the uncropped datum exactly.
  Dtype* mean = NULL;
  if (has_mean_file) {
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  // Per-channel mean values: either one value for all channels, or one per channel.
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
  }
  // Broadcast a single mean value by indexing rather than by appending copies
  // to mean_values_. Appending would permanently change the transformer, and
  // a later datum with a different channel count would then fail the CHECK above.
  const bool broadcast_mean_value =
      has_mean_values && mean_values_.size() == 1;

  // Output geometry and crop offsets.
  int height = datum_height;
  int width = datum_width;
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    height = crop_size;
    width = crop_size;
    if (phase_ == TRAIN) {
      // Random crop while training.
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {
      // Deterministic center crop otherwise.
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  for (int c = 0; c < datum_channels; ++c) {
    // Per-channel mean is loop-invariant for this channel; only used without a mean file.
    const Dtype channel_mean = has_mean_values ?
        mean_values_[broadcast_mean_value ? 0 : c] : Dtype(0);
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Index into the uncropped source (and into the mean image).
        const int data_index =
            (c * datum_height + h_off + h) * datum_width + w_off + w;
        // Index into the cropped output, flipping the column when mirroring.
        const int out_w = do_mirror ? (width - 1 - w) : w;
        const int top_index = (c * height + h) * width + out_w;

        Dtype datum_element;
        if (has_uint8) {
          // Bytes are stored in a std::string, whose char may be signed; going
          // through uint8_t keeps values >= 128 from becoming negative.
          datum_element =
              static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        } else {
          datum_element = datum.float_data(data_index);
        }

        if (has_mean_file) {
          transformed_data[top_index] =
              (datum_element - mean[data_index]) * scale;
        } else if (has_mean_values) {
          transformed_data[top_index] =
              (datum_element - channel_mean) * scale;
        } else {
          transformed_data[top_index] = datum_element * scale;
        }
      }
    }
  }
}
3、获得输出Blob的尺寸
主要有以下几种函数:
// InferBlobShape overloads: report the shape of the Blob that Transform would
// produce for the given input(s). The single-input versions return
// {1, channels, height, width}, with height/width replaced by crop_size when
// cropping is enabled. The vector versions presumably use the batch size as
// the first dimension (definitions not shown here).
vector<int> InferBlobShape(const Datum& datum);
vector<int> InferBlobShape(const vector<Datum> & datum_vector);
vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);
vector<int> InferBlobShape(const cv::Mat& cv_img);
/**
 * Computes the output shape Transform would produce for a single Datum.
 *
 * An encoded datum is first decoded to a cv::Mat (OpenCV builds only) and
 * handed to the cv::Mat overload. In that decode step, a forced color/gray
 * mode takes precedence over the image's native channel layout.
 * Otherwise the result is {1, channels, H, W}, where H and W become crop_size
 * whenever crop_size is nonzero.
 */
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
  if (datum.encoded()) {
#ifdef USE_OPENCV
    CHECK(!(param_.force_color() && param_.force_gray()))<< "cannot set both force_color and force_gray";
    const bool forced_mode = param_.force_color() || param_.force_gray();
    const cv::Mat decoded = forced_mode
        ? DecodeDatumToCVMat(datum, param_.force_color())
        : DecodeDatumToCVMatNative(datum);
    return InferBlobShape(decoded);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  }

  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // The crop window must fit inside the image.
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  const bool cropping = crop_size != 0;
  vector<int> shape;
  shape.push_back(1);  // one item
  shape.push_back(datum_channels);
  shape.push_back(cropping ? crop_size : datum_height);
  shape.push_back(cropping ? crop_size : datum_width);
  return shape;
}
/**
 * Computes the output shape Transform would produce for a single cv::Mat.
 *
 * @return {1, channels, H, W}, where H and W are crop_size when crop_size is
 *         nonzero, otherwise the image's rows and cols.
 */
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
  const int crop_size = param_.crop_size();
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // The crop window must fit inside the image.
  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  int out_height = img_height;
  int out_width = img_width;
  if (crop_size != 0) {
    out_height = crop_size;
    out_width = crop_size;
  }
  const int dims[] = {1, img_channels, out_height, out_width};
  return vector<int>(dims, dims + 4);
}
4、初始化随机种子
/**
 * (Re)creates or clears the random number generator rng_.
 *
 * A generator is needed only when mirroring is enabled, or when random
 * cropping happens (TRAIN phase with a nonzero crop_size). It is seeded from
 * caffe_rng_rand() and presumably backs Rand(). In every other case rng_ is
 * reset to empty.
 */
template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  const bool wants_mirror = param_.mirror();
  const bool wants_random_crop = (phase_ == TRAIN) && param_.crop_size();
  if (!wants_mirror && !wants_random_crop) {
    rng_.reset();
    return;
  }
  const unsigned int seed = caffe_rng_rand();
  rng_.reset(new Caffe::RNG(seed));
}