Caffe源码,训练流程分析

本文深入剖析Caffe的训练流程,从入口main函数开始,详细解释SolverParameter解析、Solver和Net初始化、网络权值初始化以及训练解算过程,包括前向和反向传播、损失平滑和权值更新。通过对Caffe源码的分析,揭示了训练网络的内部工作机制。
摘要由CSDN通过智能技术生成

1. 前言

1.1 Caffe结构简单梳理

在之前的文章(Caffe源码整体结构及介绍)中介绍了Caffe中的一些重要的组件:
1)Blob 主要用来表示网络中的数据,包括训练数据,网络各层自身的参数(包括权值、偏置以及它们的梯度),网络之间传递的数据都是通过 Blob 来实现的,同时 Blob 数据也支持在 CPU 与 GPU 上存储,能够在两者之间做同步。
2)Layer 是对神经网络中各种层的一个抽象,包括我们熟知的卷积层和下采样层,还有全连接层和各种激活函数层等等。同时每种 Layer 都实现了前向传播和反向传播,并通过 Blob 来传递数据。
3)Net 是对整个网络的表示,由各种 Layer 前后连接组合而成,也是我们所构建的网络模型。
4)Solver 定义了针对 Net 网络模型的求解方法,记录网络的训练过程,保存网络模型参数,中断并恢复网络的训练过程。自定义 Solver 能够实现不同的网络求解方式。

通过去看对应模块的代码能够知道每个模块里面是怎么运作的,但是这些模块并不是单独的个体,在训练过程中是相互结合的,这里就以一个网络的训练为例子,看看整个训练的流程是什么样子的,是如何将这些单独的模块串联起来的。

cd $CAFFE_ROOT
./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt

1.2 训练的总体流程

这里主要介绍Caffe中进行训练的流程,下面是调用Caffe中的训练函数之后发生的事情的大体流程:
在这里插入图片描述

2. 训练流程的初始化

2.1 训练的入口main函数

Caffe中启动训练也是需要调用程序入口main()函数的,下面是经过简略保留关键函数调用GetBrewFunction()的main结构:PS:其中的......代表相关代码的省略

int main(int argc, char** argv) {
   
      ......
      return GetBrewFunction(caffe::string(argv[1]))();
      ......
}

上面的GetBrewFunction()函数中是通过指定的命令名称找到对应的函数指针进行回调的,对应的存储结构是std::map类型,变量为g_brew_map,也就是一个注册器。首先对于g_brew_map的定义:

typedef int (*BrewFunction)();
typedef std::map<caffe::string, BrewFunction> BrewMap;
BrewMap g_brew_map;

g_brew_map是一个key为std::string类型,value为BrewFunction类型的一个map类型的全局变量,而BrewFunction是一个函数指针类型,它指向的是参数为空,返回值为int的函数,也就是train、test、time、device_query这四个函数的类型。g_brew_map本质是一个容器。

注册器具体定义为一个宏

#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
 public: /* NOLINT */ \
  __Registerer_##func() {
    \
    g_brew_map[#func] = &func; \
  } \
}; \
__Registerer_##func g_registerer_##func; \
}

具体g_brew_map实现过程(回调过程),首先通过typedef定义函数指针 typedef int (*BrewFunction)(); 这个是用typedef定义函数指针方法。这个程序定义一个BrewFunction函数指针类型,在caffe.cpp中BrewFunction作为GetBrewFunction()函数的返回类型,可以是train(),test(),device_query(),time()这四个函数指针的其中一个。在train(),test(),中可以调用solver类的函数,从而进入到net,进入到每一层,运行整个Caffe程序。然后对每个函数注册。
下面这几个是原本caffe.cpp文件中注册了的几个函数:

RegisterBrewFunction(train)
RegisterBrewFunction(test)
RegisterBrewFunction(device_query)
RegisterBrewFunction(time)

1)train: 训练或者调整一个模型;
2)test: 在测试集上测试一个模型;
3)device_query: 打印GPU的调试信息;
4)time: 压测一个模型的执行时间,包含前向和后向传播中各层的运行时间;
如果需要,可以增加其他的方式,然后通过RegisterBrewFunction()函数注册一下即可。

GetBrewFunction()函数通过指定参数,传入”train”参数就会通过键值对调用train()函数,train函数中主要有三个方法ReadSolverParamsFromTextFileOrDie、CreateSolver、Solve,分别代表的是从文件加载训练网络的相关参数,创建solver以及求解网络的过程。

// Train / Finetune a model.
int train() {
   
  ......
  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param
  ......
  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式

  if (FLAGS_snapshot.size()) {
   //迭代snapshot次后保存模型一次
    LOG(INFO) << "Resuming from " << FLAGS_snapshot;
    solver->Restore(FLAGS_snapshot.c_str());
  } else if (FLAGS_weights.size()) {
   //若采用finetuning,则拷贝weight到指定模型
    CopyLayers(solver.get(), FLAGS_weights);
  }

  if (gpus.size() > 1) {
   
    caffe::P2PSync<float> sync(solver, NULL, solver->param());
    sync.Run(gpus);
  } else {
   
    LOG(INFO) << "Starting Optimization";
    solver->Solve();//开始训练网络
  }
  LOG(INFO) << "Optimization Done.";
  return 0;
}

其中的调用caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param)解析-solver指定的solver.prototxt的文件内容到solver_param中用于后序创建solver。

2.2 SolverParameter的具体解析过程

上面代码中的SolverParameter是通过ReadSolverParamsFromTextFileOrDie()来完成解析的,这个函数的实现在/CAFFE_ROOT/src/caffe/util/upgrade_proto.cpp里,我们来看一下具体的过程:

// Read parameters from a file into a SolverParameter proto message.
 void ReadSolverParamsFromTextFileOrDie(const string& param_file,
                                        SolverParameter* param) {
   
   CHECK(ReadProtoFromTextFile(param_file, param))
       << "Failed to parse SolverParameter file: " << param_file;
   UpgradeSolverAsNeeded(param_file, param);
 }

这里调用了先后调用了两个函数,首先是ReadProtoFromTextFile,这个函数的作用是从param_file这个路径去读取solver的定义,并将文件中的内容解析存到param这个指针指向的对象,具体的实现在/CAFFE_ROOT/src/caffe/util/io.cpp的开始:

 bool ReadProtoFromTextFile(const char* filename, Message* proto) {
   
   int fd = open(filename, O_RDONLY);
   CHECK_NE(fd, -1) << "File not found: " << filename;
   FileInputStream* input = new FileInputStream(fd);
   bool success = google::protobuf::TextFormat::Parse(input, proto);
   delete input;
   close(fd);
   return success;
 }

这段代码首先是打开了文件,并且读取到了一个FileInputStream的指针中,然后通过protobuf的TextFormat::Parse函数完成了解析。

然后UpgradeSolverAsNeeded完成了新老版本caffe.proto的兼容处理:

// Check for deprecations and upgrade the SolverParameter as needed.
  bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
   
    bool success = true;
    // Try to upgrade old style solver_type enum fields into new string type
    if (SolverNeedsTypeUpgrade(*param)) {
   
      LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
                << "'solver_type' field (enum)': " << param_file;
      if (!UpgradeSolverType(param)) {
   
        success = false;
       LOG(ERROR) << "Warning: had one or more problems upgrading "
                  << "SolverType (see above).";
     } else {
   
       LOG(INFO) << "Successfully upgraded file specified using deprecated "
                 << "'solver_type' field (enum) to 'type' field (string).";
       LOG(WARNING) << "Note that future Caffe releases will only support "
                    << "'type' field (string) for a solver's type.";
     }
   }
   return success;
 }

主要的问题就是在旧版本中Solver的type是enum类型,而新版本的变为了string

2.3 Solver初始化

在上面train代码函数中使用如下代码进行solver的初始化工作,使用下面的代码用于初始化Solver和Net:

//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式
shared_ptr<caffe::
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值