caffe tool之compute_image_mean源码解析

这里我们分析一下caffe源码中提供的一个计算图片均值的小工具,源码为caffe/tools/compute_image_mean.cpp
首先介绍一下里面用到的struct:
1.BlobProto sum_blob; blob类是caffe中使用的基础的数据类,利用protobuf工具来自动生成的,定义在/caffe/src/proto/caffe.proto文件中,我们可以看到如下内容:

message BlobProto {
  optional BlobShape shape = 7;
  repeated float data = 5 [packed = true];
  repeated float diff = 6 [packed = true];
  repeated double double_data = 8 [packed = true];
  repeated double double_diff = 9 [packed = true];

  // 4D dimensions -- deprecated.  Use "shape" instead.
  optional int32 num = 1 [default = 0];
  optional int32 channels = 2 [default = 0];
  optional int32 height = 3 [default = 0];
  optional int32 width = 4 [default = 0];
}

通过上面的定义我们可以看到,这里面定义了维度信息,channel x height x width可以代表一张图片, num代表图片的数量。前面几行是定义了数据,里面可以存放图片的像素值,也可以存放diff value,即可以放float值,也可以放double值。所以说,一个BlobProto结构体可以存放多张图片,也就是一批图片,所有图片的数据则是按顺序存放在一个数组中。这个struct还可以用于和文件的交互,也就是利用protobuf工具将BlobProto中的数据序列化之后存入文件,也可以从文件中读取数据来初始化一个BlobProto对象,这样就可以将training的结果存入文件,下次使用时从文件中载入。

2,Datum: 一张图片
也是在/caffe/src/proto/caffe.proto中定义的,用来代表一张图片,也就是没有num这个变量,定义如下。

message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
}

Datum可以和存储图片数据的数据库交互,在下面的例子中就是图片的db中的存储方式为(key, value)方式,一个key对应一张图片value,Datum提供了接口利用value(字符数组格式)来初始化一个Datum。这样就可以从数据库中依次读取每张图片来做后续处理。在io.cpp中还提供了一些辅助函数来帮助和文件的交互。比如下面用到的DecodeDatumNative(&datum);就是利用的opencv的encode/decode函数来对Datum中的数据进行编码或者解码。在Datum中有一个flag来表示内部的数据是否是编码的数据。
3. lmdb : 存储数据的数据库
文件location : caffe/include/caffe/util/db.hpp, db_lmdb.hpp, db_leveldb.hpp caffe/src/caffe/util/db.cpp, db_lmdb.cpp db_leveldb.cpp
在caffe中使用了两种数据库leveldb和lmdb, 一般使用lmdb。这里主要介绍lmdb。
db.hpp中提供了create函数,根据传入的数据库的类别创建对应的数据库的对象,作为数据库的基类使用。
在db_lmdb.h中创建了三个类,LMDBCursor 数据库的游标,用来索引数据库中的各个元素,并获得所指向的位置的key和value值。LMDBTransaction实现将数据put进数据库和commit到数据库。put只是提交到数据库,没有murge到主干上。commit命令是将前面put的数据murge到主干上。
class LMDB : 提供对数据库的打开关闭等,并在内部创建上面两个类的对象,提供接口让外部或者上面两个类的对象的指针,让app可以操作数据库。因为在caffe中只是用数据库来存放数据,没有用到复杂的功能,所以这几天类都比较简单,提供基本的功能。

下面就是compute_image_mean.cpp的源码分析了:

#include <stdint.h>
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "boost/scoped_ptr.hpp" //boost的智能指针
#include "gflags/gflags.h"  //负责输入参数的解析
#include "glog/logging.h" //负责log记录及输出

#include "caffe/proto/caffe.pb.h"//caffe的proto文件,里面记录了定义的各种需要的proto struct
#include "caffe/util/db.hpp" // 存储image的数据库
#include "caffe/util/io.hpp" //与文件io的一些函数

using namespace caffe;  // NOLINT(build/namespaces)

using std::max;
using std::pair;
using boost::scoped_ptr;

DEFINE_string(backend, "lmdb",
        "The backend {leveldb, lmdb} containing the images");

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  // Print output to stderr (while still logging)
  FLAGS_alsologtostderr = 1;

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Compute the mean_image of a set of images given by"
        " a leveldb/lmdb\n"
        "Usage:\n"
        "    compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n");

  gflags::ParseCommandLineFlags(&argc, &argv, true); //利用gflags解析输入的命令行

  if (argc < 2 || argc > 3) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean");
    return 1;
  }
//db基类提供的create函数,利用传入的backend来选择创建LMDB还是LEVELDBD对象.
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[1], db::READ); //打开指定的文件,以只读方式打开
  scoped_ptr<db::Cursor> cursor(db->NewCursor()); //获得该数据库的游标Cursor。
//在caffe.proto中定义的struct,里面包含了width, height, channel, num四个维度,和放data的buffer,
//所以一个BlobProto可以存放多张图片,也就是一批图片
  BlobProto sum_blob;  
  int count = 0;
  // load first datum
  //也是在caffe.proto中定义的struct,只有三个维度width, height, channel,也就是一张图片的数据。
  //里面有一个flag来表示数据是否编码过。在io.cpp中提供了对Datum中的数据编码的函数,利用了opencv的encode函数。
  Datum datum; 
   //从一个string中获得data, 在数据库中存储的数据格式为(key, value),value就是一张图片的数据
  datum.ParseFromString(cursor->value());

  if (DecodeDatumNative(&datum)) { //io中提供的对datum decode的函数,利用了opencv的encode/decode函数
    LOG(INFO) << "Decoding Datum";
  }

  sum_blob.set_num(1); //因为sum_blob是放sum的i地方,所以只有一张图片的数据,num设置为1
  sum_blob.set_channels(datum.channels()); //图片的channel
  sum_blob.set_height(datum.height()); //图片的height
  sum_blob.set_width(datum.width()); //图片的width
  const int data_size = datum.channels() * datum.height() * datum.width(); //图片真实数据的长度,w x h x channel
  int size_in_datum = std::max<int>(datum.data().size(), datum.float_data_size());  //数据可能是uint8,也可能是float 
  for (int i = 0; i < size_in_datum; ++i) {
    sum_blob.add_data(0.); //每个元素初始化为0
  }
  LOG(INFO) << "Starting iteration";
  while (cursor->valid()) { //该cursor所在的位置valid,一直到所有数据结束
    Datum datum;
    datum.ParseFromString(cursor->value()); // 将cursor中的value,也就是一张图片放到datum中
    DecodeDatumNative(&datum); //如果数据encode过后,需要decode

    const std::string& data = datum.data(); //datum的内部data是用char数组来存储的
    size_in_datum = std::max<int>(datum.data().size(),
        datum.float_data_size());
    CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
        size_in_datum;
    if (data.size() != 0) { //如果数据是uint8
      CHECK_EQ(data.size(), size_in_datum);
      for (int i = 0; i < size_in_datum; ++i) {
        sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]); // 计算各个对象像素的值的和,以便于求均值
      }
    } else { // 数据是float
      CHECK_EQ(datum.float_data_size(), size_in_datum);
      for (int i = 0; i < size_in_datum; ++i) {
        sum_blob.set_data(i, sum_blob.data(i) +
            static_cast<float>(datum.float_data(i)));
      }
    }
    ++count;
    if (count % 10000 == 0) { // 每隔10000个图片打出一条log
      LOG(INFO) << "Processed " << count << " files.";
    }
    cursor->Next(); //cursor指向下一个数据
  }

  if (count % 10000 != 0) {
    LOG(INFO) << "Processed " << count << " files.";
  }
  for (int i = 0; i < sum_blob.data_size(); ++i) {
    sum_blob.set_data(i, sum_blob.data(i) / count); // 每个sum除以count来计算均数
  }
  // Write to disk
  if (argc == 3) {
    LOG(INFO) << "Write to " << argv[2];
    WriteProtoToBinaryFile(sum_blob, argv[2]); //io提供的函数来利用protobuf库来将数据写入文件
  }
  const int channels = sum_blob.channels();
  const int dim = sum_blob.height() * sum_blob.width();
  std::vector<float> mean_values(channels, 0.0);
  LOG(INFO) << "Number of channels: " << channels;
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < dim; ++i) {
      mean_values[c] += sum_blob.data(dim * c + i); //这里计算每个channel的均值,也就是一张图片20x20, 400个像素的均值
    }
    LOG(INFO) << "mean_value channel [" << c << "]: " << mean_values[c] / dim;
  }
  return 0;
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值