这里我们分析一下caffe源码中提供的一个计算图片均值的小工具,源码为caffe/tools/compute_image_mean.cpp
首先介绍一下里面用到的struct:
1.BlobProto sum_blob; blob类是caffe中使用的基础的数据类,利用protobuf工具来自动生成的,定义在/caffe/src/proto/caffe.proto文件中,我们可以看到如下内容:
message BlobProto {
optional BlobShape shape = 7;
repeated float data = 5 [packed = true];
repeated float diff = 6 [packed = true];
repeated double double_data = 8 [packed = true];
repeated double double_diff = 9 [packed = true];
// 4D dimensions -- deprecated. Use "shape" instead.
optional int32 num = 1 [default = 0];
optional int32 channels = 2 [default = 0];
optional int32 height = 3 [default = 0];
optional int32 width = 4 [default = 0];
}
通过上面的定义我们可以看到,这里面定义了维度信息,channel x height x width可以代表一张图片, num代表图片的数量。前面几行是定义了数据,里面可以存放图片的像素值,也可以存放diff value,即可以放float值,也可以放double值。所以说,一个BlobProto结构体可以存放多张图片,也就是一批图片,所有图片的数据则是按顺序存放在一个数组中。这个struct还可以用于和文件的交互,也就是利用protobuf工具将BlobProto中的数据序列化之后存入文件,也可以从文件中读取数据来初始化一个BlobProto对象,这样就可以将training的结果存入文件,下次使用时从文件中载入。
2,Datum: 一张图片
也是在/caffe/src/proto/caffe.proto中定义的,用来代表一张图片,也就是没有num这个变量,定义如下。
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
// the actual image data, in bytes
optional bytes data = 4;
optional int32 label = 5;
// Optionally, the datum could also hold float data.
repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded
optional bool encoded = 7 [default = false];
}
Datum可以和存储图片数据的数据库交互,在下面的例子中就是图片的db中的存储方式为(key, value)方式,一个key对应一张图片value,Datum提供了接口利用value(字符数组格式)来初始化一个Datum。这样就可以从数据库中依次读取每张图片来做后续处理。在io.cpp中还提供了一些辅助函数来帮助和文件的交互。比如下面用到的DecodeDatumNative(&datum);就是利用的opencv的encode/decode函数来对Datum中的数据进行编码或者解码。在Datum中有一个flag来表示内部的数据是否是编码的数据。
3. lmdb : 存储数据的数据库
文件location : caffe/include/caffe/util/db.hpp, db_lmdb.hpp, db_leveldb.hpp caffe/src/caffe/util/db.cpp, db_lmdb.cpp db_leveldb.cpp
在caffe中使用了两种数据库leveldb和lmdb, 一般使用lmdb。这里主要介绍lmdb。
db.hpp中提供了create函数,根据传入的数据库的类别创建对应的数据库的对象,作为数据库的基类使用。
在db_lmdb.h中创建了三个类,LMDBCursor 数据库的游标,用来索引数据库中的各个元素,并获得所指向的位置的key和value值。LMDBTransaction实现将数据put进数据库和commit到数据库。put只是提交到数据库,没有murge到主干上。commit命令是将前面put的数据murge到主干上。
class LMDB : 提供对数据库的打开关闭等,并在内部创建上面两个类的对象,提供接口让外部或者上面两个类的对象的指针,让app可以操作数据库。因为在caffe中只是用数据库来存放数据,没有用到复杂的功能,所以这几天类都比较简单,提供基本的功能。
下面就是compute_image_mean.cpp的源码分析了:
#include <stdint.h>
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include "boost/scoped_ptr.hpp" //boost的智能指针
#include "gflags/gflags.h" //负责输入参数的解析
#include "glog/logging.h" //负责log记录及输出
#include "caffe/proto/caffe.pb.h"//caffe的proto文件,里面记录了定义的各种需要的proto struct
#include "caffe/util/db.hpp" // 存储image的数据库
#include "caffe/util/io.hpp" //与文件io的一些函数
using namespace caffe; // NOLINT(build/namespaces)
using std::max;
using std::pair;
using boost::scoped_ptr;
DEFINE_string(backend, "lmdb",
"The backend {leveldb, lmdb} containing the images");
int main(int argc, char** argv) {
::google::InitGoogleLogging(argv[0]);
// Print output to stderr (while still logging)
FLAGS_alsologtostderr = 1;
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
gflags::SetUsageMessage("Compute the mean_image of a set of images given by"
" a leveldb/lmdb\n"
"Usage:\n"
" compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n");
gflags::ParseCommandLineFlags(&argc, &argv, true); //利用gflags解析输入的命令行
if (argc < 2 || argc > 3) {
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean");
return 1;
}
//db基类提供的create函数,利用传入的backend来选择创建LMDB还是LEVELDBD对象.
scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
db->Open(argv[1], db::READ); //打开指定的文件,以只读方式打开
scoped_ptr<db::Cursor> cursor(db->NewCursor()); //获得该数据库的游标Cursor。
//在caffe.proto中定义的struct,里面包含了width, height, channel, num四个维度,和放data的buffer,
//所以一个BlobProto可以存放多张图片,也就是一批图片
BlobProto sum_blob;
int count = 0;
// load first datum
//也是在caffe.proto中定义的struct,只有三个维度width, height, channel,也就是一张图片的数据。
//里面有一个flag来表示数据是否编码过。在io.cpp中提供了对Datum中的数据编码的函数,利用了opencv的encode函数。
Datum datum;
//从一个string中获得data, 在数据库中存储的数据格式为(key, value),value就是一张图片的数据
datum.ParseFromString(cursor->value());
if (DecodeDatumNative(&datum)) { //io中提供的对datum decode的函数,利用了opencv的encode/decode函数
LOG(INFO) << "Decoding Datum";
}
sum_blob.set_num(1); //因为sum_blob是放sum的i地方,所以只有一张图片的数据,num设置为1
sum_blob.set_channels(datum.channels()); //图片的channel
sum_blob.set_height(datum.height()); //图片的height
sum_blob.set_width(datum.width()); //图片的width
const int data_size = datum.channels() * datum.height() * datum.width(); //图片真实数据的长度,w x h x channel
int size_in_datum = std::max<int>(datum.data().size(), datum.float_data_size()); //数据可能是uint8,也可能是float
for (int i = 0; i < size_in_datum; ++i) {
sum_blob.add_data(0.); //每个元素初始化为0
}
LOG(INFO) << "Starting iteration";
while (cursor->valid()) { //该cursor所在的位置valid,一直到所有数据结束
Datum datum;
datum.ParseFromString(cursor->value()); // 将cursor中的value,也就是一张图片放到datum中
DecodeDatumNative(&datum); //如果数据encode过后,需要decode
const std::string& data = datum.data(); //datum的内部data是用char数组来存储的
size_in_datum = std::max<int>(datum.data().size(),
datum.float_data_size());
CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
size_in_datum;
if (data.size() != 0) { //如果数据是uint8
CHECK_EQ(data.size(), size_in_datum);
for (int i = 0; i < size_in_datum; ++i) {
sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]); // 计算各个对象像素的值的和,以便于求均值
}
} else { // 数据是float
CHECK_EQ(datum.float_data_size(), size_in_datum);
for (int i = 0; i < size_in_datum; ++i) {
sum_blob.set_data(i, sum_blob.data(i) +
static_cast<float>(datum.float_data(i)));
}
}
++count;
if (count % 10000 == 0) { // 每隔10000个图片打出一条log
LOG(INFO) << "Processed " << count << " files.";
}
cursor->Next(); //cursor指向下一个数据
}
if (count % 10000 != 0) {
LOG(INFO) << "Processed " << count << " files.";
}
for (int i = 0; i < sum_blob.data_size(); ++i) {
sum_blob.set_data(i, sum_blob.data(i) / count); // 每个sum除以count来计算均数
}
// Write to disk
if (argc == 3) {
LOG(INFO) << "Write to " << argv[2];
WriteProtoToBinaryFile(sum_blob, argv[2]); //io提供的函数来利用protobuf库来将数据写入文件
}
const int channels = sum_blob.channels();
const int dim = sum_blob.height() * sum_blob.width();
std::vector<float> mean_values(channels, 0.0);
LOG(INFO) << "Number of channels: " << channels;
for (int c = 0; c < channels; ++c) {
for (int i = 0; i < dim; ++i) {
mean_values[c] += sum_blob.data(dim * c + i); //这里计算每个channel的均值,也就是一张图片20x20, 400个像素的均值
}
LOG(INFO) << "mean_value channel [" << c << "]: " << mean_values[c] / dim;
}
return 0;
}