从零开始的CAFFE代码阅读

前言

  作为一个小白,以前只是傻乎乎的用tensorflow啊,caffe之类的框架,别人给啥我用啥,别人写了自定义层我就按照教程插入一蛤。现在看来,这种拿来主义很限制威猛的我向上发展啊,所以打算借用课余时光,好好阅读一下caffe的源码,也不说一定有用,了解了解大佬们的代码逻辑,读书千遍其义自见嘛~
  源码的阅读呢打算按照两个阶段走。

  1. 理解源码里的文件的作用
      这一阶段目标是梳理caffe的实现逻辑,理解源码中关键接口的作用,例如prototxt是怎么实现网络搭建的?caffe的工厂模式是怎么实现"网络生产的"。这一阶段并不涉及网络的具体实现源码。
  2. 阅读并理解网络层的实现
      这一阶段目标是阅读caffe网络层的源码,知道他是咋实现的,咋用GPU加速了,我也就知道我该怎么写我的自定义层了。

源码阅读顺序

  其实已经有好多老哥们都写了关于caffe代码阅读的博客,但是那,光看别人的消化结果在自己个人的理解上总觉得差点意思,所以还是自己亲身实践一下.具体的caffe代码结构的梳理呀这里就不再赘述了,好多前辈们的总结都特别好。这里分享一篇很适合作为代码导读的博客,我也从这篇博客出发 ,开启我的caffe源码阅读之路啦。参考博文

  借用博客里的一张图,来简要的看一下caffe的框架:

图1
  上图描述了caffe实现网络的逻辑。通过blob,或者说是tensor,作为数据的载体在layer之间传递,可以存储数据,也可以求导。Layer则对数据进行conv, pool, BN等操作。layer在组成net,形成一个完整的神经网络。
  但是,这是咋实现的呢?道理我都懂,但是我还是不明白呀!。我们都知道在识别或者训练的时候,我们需要.prototxt文件和.caffemodel文件的路径。这时候Caffe就会把我们的这两个文件转化成网络和对应权值了。具体咋实现的呢,这里再贴一下之前博客的一张图:

Step 1. caffe.proto:对应目录 caffe-master\src\caffe\proto\caffe.proto

Step 2. Hpp文件:
   a solver.hpp  — caffe-master\include\caffe\net.hpp
   b net.hpp   — caffe-master\include\caffe\net.hpp
   c layer.hpp  — caffe-master\include\caffe\layer.hpp
   d blob.hpp  — caffe-master\include\caffe\blob.hpp
上面d,c,b,a这4个部分实际上是自底向上的结构。

Step 3. Cpp/cu文件
对应上面提到的blob、net、solver的具体实现。通常说来,caffe框架已经实现了很通用的网络结构,如果有自己的需求,添加一些新的层次即可)

Step 4.
tools文件caffe提供的工具,目录在caffe-master\tools,例如计算图像均值,调优网络,可视化等。

   这个是那个博主整理的阅读顺序,我的想法呢也是参考了他的思路,然后可能会说的更加细致一些,毕竟我底子比较差嘛~

Protobuffer是什么?caffe.proto做了啥?

   protobuffer到底是个啥?功能是啥呢?网上相关资料很多,我呢也看了之后算个一知半解吧。简单的来说,我们按照protobuffer的语法写一个.proto文件,这个文件主要是用于声明一系列对象。这个.proto文件可以按照我们的要求,使用protobuffer的编译器,自动生成.对应语言的头文件和源文件。就c++而言,他会生产.cpp和.h的文件。
在这里插入图片描述
  然后呢?为啥要大费周章的生产这些文件呢?首先,他很方便呀,就拿caffe而言,他有对应的python接口也有C++接口,但是在实际使用过程中,无论你是c++还是python,你的.prototxt和.caffemodel文件都是通用的。他是咋实现的呢?就是通过caffe.proto以及其工具文件实现这个功能的。
  说白了,protobuffer就是一个将对象序列化以及反序列化的工具,帮助我们在不同平台上读取这些数据。接下来来跑一跑google官方文档的示例来帮助理解:

protobuffer示例:

  在Linux里做这个示例其实会方便很多,在windows平台下的话,这里分享一个老哥写的protobuffer教程:Protobuffer安装教程,安装啥的就按照这个老哥上面的就可以了。接下来就开始感受一下protobuffer是怎么进行序列化和反序列化的~

第一步:
  写一个.proto文件,按照官方教程,如下:

syntax="proto2";
package tutorial;
 
message Person {
  required string name = 1;
  required int32 id = 2;
  optional string email = 3;
 
  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }
 
  message PhoneNumber {
    required string number = 1;
    optional PhoneType type = 2 [default = HOME];
  }
 
  repeated PhoneNumber phone = 4;
}
 
message AddressBook {
  repeated Person person = 1;
}

具体啥含义,可以去看一下protobuffer的语法,要注意的是proto2和proto3还是有比较大区别的。这一段大体的意思就是搞了两个类,一个是Person,一个是Adressbook,效果相当于通讯簿,记录不同人的个人信息的。然后捏,我们使用对应的编译器让他形成c++的头文件和源文件(当然其他语言也可以,但是caffe用的是c++,这就对标一蛤)。

生成的文件
  别看.proto那么短,他生产的文件是定义了很多方法的,序列化呀,反序列化呀,设定对象值呀,都在里头,有兴趣的口以去看一下官方文档,或者其他博主的博客,都很详细滴,这里同样不赘述。
  在得到这俩文件之后,再写一个主函数,给对象赋值,并将对象信息序列化成一个二进制文件,主函数如下:

#include <iostream>
#include <fstream>
#include <string>
#include "addressbook.pb.h"
using namespace std;
// This function fills in a Person message based on user input.
void PromptForAddress(tutorial::Person* person) {
 cout << "Enter person ID number: ";
 int id;
 cin >> id;
 person->set_id(id);
 cin.ignore(256, '\n');           // cin.ignore(a,ch)方法是从输入流(cin)中提取字符,提取的字符被忽略(ignore),不被使用。每抛弃一个字符,它都要计数和比较字符:如果计数值达到a或者被抛弃的字符是ch,则cin.ignore()函数执行终止;否则,它继续等待。它的一个常用功能就是用来清除以回车结束的输入缓冲区的内容,消除上一次输入对下一次输入的影响。比如可以这么用:cin.ignore(1024,'\n'),通常把第一个参数设置得足够大,这样实际上总是只有第二个参数'\n'起作用,所以这一句就是把回车(包括回车)之前的所以字符从输入缓冲(流)中清除出去。
 cout << "Enter name: ";
 getline(cin, *person->mutable_name());
 cout << "Enter email address (blank for none): ";
 string email;
 getline(cin, email);
 if (!email.empty()) {
  person->set_email(email);
 }
 while (true) {
  cout << "Enter a phone number (or leave blank to finish): ";
  string number;
  getline(cin, number);
  if (number.empty()) {
   break;
  }
  tutorial::Person::PhoneNumber* phone_number = person->add_phone();
  phone_number->set_number(number);
  cout << "Is this a mobile, home, or work phone? ";
  string type;
  getline(cin, type);
  if (type == "mobile") {
   phone_number->set_type(tutorial::Person::MOBILE);
  }
  else if (type == "home") {
   phone_number->set_type(tutorial::Person::HOME);
  }
  else if (type == "work") {
   phone_number->set_type(tutorial::Person::WORK);
  }
  else {
   cout << "Unknown phone type.  Using default." << endl;
  }
 }
}
// Main function:  Reads the entire address book from a file,
//   adds one person based on user input, then writes it back out to the same
//   file.
int main(int argc, char* argv[]) {
 // Verify that the version of the library that we linked against is
 // compatible with the version of the headers we compiled against.
 GOOGLE_PROTOBUF_VERIFY_VERSION;
 if (argc != 2) {
  cerr << "Usage:  " << argv[0] << " ADDRESS_BOOK_FILE" << endl;
  return -1;
 }
 tutorial::AddressBook address_book;
 {
  // Read the existing address book.
  fstream input(argv[1], ios::in | ios::binary);
  if (!input) {
   cout << argv[1] << ": File not found.  Creating a new file." << endl;
  }
  else if (!address_book.ParseFromIstream(&input)) {
   cerr << "Failed to parse address book." << endl;
   return -1;
  }
 }
 // Add an address.
 PromptForAddress(address_book.add_person());
 {
  // Write the new address book back to disk.
  fstream output(argv[1], ios::out | ios::trunc | ios::binary);
  if (!address_book.SerializeToOstream(&output)) {
   cerr << "Failed to write address book." << endl;
   return -1;
  }
 }
 // Optional:  Delete all global objects allocated by libprotobuf.
 google::protobuf::ShutdownProtobufLibrary();
 return 0;
}

  这个主函数会生成一个可执行程序,然后你输入生成后二进制文件的名字,就可以生成对应的二进制文件了,过程如下。
在这里插入图片描述

保存好的文件如下:
在这里插入图片描述
序列化过程到此结束。我们的dick man和他的信息就被存到了这个二进制文件中。这也是protobuffer的一个缺点,尽管他的效率高,压缩比大,但是,序列化完了之后可读性很差,谁也不知道里面写了啥。
  接下来我们要把这个文件反序列化,就是把他给读懂,反序列化的主函数如下:

#include <iostream>
#include <fstream>
#include <string>
#include "addressbook.pb.h"
using namespace std;
// Iterates though all people in the AddressBook and prints info about them.
void ListPeople(const tutorial::AddressBook& address_book) {
 for (int i = 0; i < address_book.person_size(); i++) {
  const tutorial::Person& person = address_book.person(i);
  cout << "Person ID: " << person.id() << endl;
  cout << "  Name: " << person.name() << endl;
  if (person.has_email()) {
   cout << "  E-mail address: " << person.email() << endl;
  }
  for (int j = 0; j < person.phone_size(); j++) {
   const tutorial::Person::PhoneNumber& phone_number = person.phone(j);
   switch (phone_number.type()) {
   case tutorial::Person::MOBILE:
    cout << "  Mobile phone #: ";
    break;
   case tutorial::Person::HOME:
    cout << "  Home phone #: ";
    break;
   case tutorial::Person::WORK:
    cout << "  Work phone #: ";
    break;
   }
   cout << phone_number.number() << endl;
  }
 }
}
// Main function:  Reads the entire address book from a file and prints all
//   the information inside.
int main(int argc, char* argv[]) {
 // Verify that the version of the library that we linked against is
 // compatible with the version of the headers we compiled against.
 GOOGLE_PROTOBUF_VERIFY_VERSION;
 if (argc != 2) {
  cerr << "Usage:  " << argv[0] << " ADDRESS_BOOK_FILE" << endl;
  return -1;
 }
 tutorial::AddressBook address_book;
 {
  // Read the existing address book.
  fstream input(argv[1], ios::in | ios::binary);
  if (!address_book.ParseFromIstream(&input)) {
   cerr << "Failed to parse address book." << endl;
   return -1;
  }
 }
 ListPeople(address_book);
 // Optional:  Delete all global objects allocated by libprotobuf.
 google::protobuf::ShutdownProtobufLibrary();
 return 0;
}

原谅我用的是同一个工程,这个主函数还是会生成一个可执行程序,你把对应二进制文件往里面一输,他就会把结果反序列化:

在这里插入图片描述
这样一来,我们就可以那么理解protobuffer在caffe里的作用了。.proto文件帮助我们生成一大堆类。由于整个文件又香又长,这里就以最为熟悉的卷积层为例,看看他到底定义了个啥:
markdown抽风了呀,复制不进去代码

这个就是caffe定义的卷积层对应的参数了,卷积核尺寸,pad尺寸,啥啥啥都在里头。caffe通过.proto这个文件,定义了每个层的类,在实际应用的时候,会将prototxt中的数据读出来,对应的赋值到.proto定义的对象中去。是不是很简单明了?明了屁哟!,其实 到现在咱们还是不知道我们的 prototxt文件是咋转化成对应的类的。还记不记得caffe的c++搭建网络的接口?

Net<float> caffe_test_net(argv[1]);

就是通过这个接口实现prototxt文件的读取的,我们看看这个接口到底干了啥,Net类的构造函数如下(net.cpp里):

Net<Dtype>::Net(const string& param_file, Phase phase,
    const int level, const vector<string>* stages) {
  NetParameter param;
  ReadNetParamsFromTextFileOrDie(param_file, &param);
  // Set phase, stages and level
  param.mutable_state()->set_phase(phase);
  if (stages != NULL) {
    for (int i = 0; i < stages->size(); i++) {
      param.mutable_state()->add_stage((*stages)[i]);
    }
  }
  param.mutable_state()->set_level(level);
  Init(param);
}

  我们可以看到ReadNetParamsFromTextFileOrDie(param_file, &param);这个函数实现了文件的读取,但是这个函数又是在哪个地方实现的呢?是在upgrade_proto.cpp这个文件里实现的。这个文件里面还有一些 便于向下兼容的函数,这里就先不关心这些,有兴趣的老哥们可以再仔细看看每个函数到底讲了啥。我们这里主要观察是哪个函数实现prototxt文件的读取。

void ReadNetParamsFromBinaryFileOrDie(const string& param_file,
                                      NetParameter* param) {
  CHECK(ReadProtoFromBinaryFile(param_file, param))
      << "Failed to parse NetParameter file: " << param_file;
  UpgradeNetAsNeeded(param_file, param);
}

发现ReadProtoFromBinaryFile这个函数是在io.cpp中定义的,然后我们再找:

bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  FileInputStream* input = new FileInputStream(fd);
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  delete input;
  close(fd);
  return success;
}

这个函数调用了google::protobuf::TextFormat::Parse(input, proto) ,将prototxt文件里的数据,传递给了NetParameter* param,这个类就是我们在caffe.proto文件里定义的东西(Markdown复制不下来呀,好痛苦,只能贴图了):

在这里插入图片描述
这段定义呢给出了咱们的blob的NWHC,网络名字,还有网络层等,最重要的就是最后的repeated LayerParameter layer,这句话的含义相当于创建了一个vector,里面的每一个类都是layer。每一个layer都有很多成员,对应卷积层之类的。

  
  
  
  
未完待续——

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值