1. 前言
1.1 Caffe结构简单梳理
在之前的文章(Caffe源码整体结构及介绍)中介绍了Caffe中的一些重要的组件:
1)Blob 主要用来表示网络中的数据,包括训练数据,网络各层自身的参数(包括权值、偏置以及它们的梯度),网络之间传递的数据都是通过 Blob 来实现的,同时 Blob 数据也支持在 CPU 与 GPU 上存储,能够在两者之间做同步。
2)Layer 是对神经网络中各种层的一个抽象,包括我们熟知的卷积层和下采样层,还有全连接层和各种激活函数层等等。同时每种 Layer 都实现了前向传播和反向传播,并通过 Blob 来传递数据。
3)Net 是对整个网络的表示,由各种 Layer 前后连接组合而成,也是我们所构建的网络模型。
4)Solver 定义了针对 Net 网络模型的求解方法,记录网络的训练过程,保存网络模型参数,中断并恢复网络的训练过程。自定义 Solver 能够实现不同的网络求解方式。
通过去看对应模块的代码能够知道每个模块里面是怎么运作的,但是这些模块并不是单独的个体,在训练过程中是相互结合的,这里就以一个网络的训练为例子,看看整个训练的流程是什么样子的,是如何将这些单独的模块串联起来的。
cd $CAFFE_ROOT
./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt
1.2 训练的总体流程
这里主要介绍Caffe中进行训练的流程,下面是调用Caffe中的训练函数之后发生的事情的大体流程:
2. 训练流程的初始化
2.1 训练的入口main函数
Caffe中启动训练也是需要调用程序入口main()
函数的,下面是经过简略保留关键函数调用GetBrewFunction()的main结构:PS:其中的......
代表相关代码的省略
int main(int argc, char** argv) {
......
return GetBrewFunction(caffe::string(argv[1]))();
......
}
上面的GetBrewFunction()
函数中是通过指定的命令名称找到对应的函数指针进行回调的,对应的存储结构是std::map
类型,变量为g_brew_map
,也就是一个注册器。首先对于g_brew_map
的定义:
typedef int (*BrewFunction)();
typedef std::map<caffe::string, BrewFunction> BrewMap;
BrewMap g_brew_map;
g_brew_map
是一个key为std::string
类型,value为BrewFunction
类型的一个map类型的全局变量,而BrewFunction
是一个函数指针类型,它指向的是参数为空,返回值为int的函数,也就是train、test、time、device_query
这四个函数的类型。g_brew_map
本质是一个容器。
注册器具体定义为一个宏
#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
public: /* NOLINT */ \
__Registerer_##func() {
\
g_brew_map[#func] = &func; \
} \
}; \
__Registerer_##func g_registerer_##func; \
}”
具体g_brew_map
实现过程(回调过程),首先通过typedef
定义函数指针 typedef int (*BrewFunction)()
; 这个是用typedef
定义函数指针方法。这个程序定义一个BrewFunction
函数指针类型,在caffe.cpp中BrewFunction
作为GetBrewFunction()
函数的返回类型,可以是train(),test(),device_query(),time()
这四个函数指针的其中一个。在train(),test()
,中可以调用solver类的函数,从而进入到net,进入到每一层,运行整个Caffe程序。然后对每个函数注册。
下面这几个是原本caffe.cpp文件中注册了的几个函数:
RegisterBrewFunction(train)
RegisterBrewFunction(test)
RegisterBrewFunction(device_query)
RegisterBrewFunction(time)
1)train: 训练或者调整一个模型;
2)test: 在测试集上测试一个模型;
3)device_query: 打印GPU的调试信息;
4)time: 压测一个模型的执行时间,包含前向和后向传播中各层的运行时间;
如果需要,可以增加其他的方式,然后通过RegisterBrewFunction()函数注册一下即可。
GetBrewFunction()
函数通过指定参数,传入”train”
参数就会通过键值对调用train()函数,train函数中主要有三个方法ReadSolverParamsFromTextFileOrDie、CreateSolver、Solve,分别代表的是从文件加载训练网络的相关参数,创建solver以及求解网络的过程。
// Train / Finetune a model.
int train() {
......
caffe::SolverParameter solver_param;
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param
......
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式
if (FLAGS_snapshot.size()) {
//迭代snapshot次后保存模型一次
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
//若采用finetuning,则拷贝weight到指定模型
CopyLayers(solver.get(), FLAGS_weights);
}
if (gpus.size() > 1) {
caffe::P2PSync<float> sync(solver, NULL, solver->param());
sync.Run(gpus);
} else {
LOG(INFO) << "Starting Optimization";
solver->Solve();//开始训练网络
}
LOG(INFO) << "Optimization Done.";
return 0;
}
其中的调用caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param)
解析-solver指定的solver.prototxt的文件内容到solver_param
中用于后序创建solver。
2.2 SolverParameter的具体解析过程
上面代码中的SolverParameter
是通过ReadSolverParamsFromTextFileOrDie()
来完成解析的,这个函数的实现在/CAFFE_ROOT/src/caffe/util/upgrade_proto.cpp里,我们来看一下具体的过程:
// Read parameters from a file into a SolverParameter proto message.
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
SolverParameter* param) {
CHECK(ReadProtoFromTextFile(param_file, param))
<< "Failed to parse SolverParameter file: " << param_file;
UpgradeSolverAsNeeded(param_file, param);
}
这里调用了先后调用了两个函数,首先是ReadProtoFromTextFile
,这个函数的作用是从param_file
这个路径去读取solver
的定义,并将文件中的内容解析存到param
这个指针指向的对象,具体的实现在/CAFFE_ROOT/src/caffe/util/io.cpp的开始:
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
int fd = open(filename, O_RDONLY);
CHECK_NE(fd, -1) << "File not found: " << filename;
FileInputStream* input = new FileInputStream(fd);
bool success = google::protobuf::TextFormat::Parse(input, proto);
delete input;
close(fd);
return success;
}
这段代码首先是打开了文件,并且读取到了一个FileInputStream
的指针中,然后通过protobuf的TextFormat::Parse
函数完成了解析。
然后UpgradeSolverAsNeeded
完成了新老版本caffe.proto的兼容处理:
// Check for deprecations and upgrade the SolverParameter as needed.
bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
bool success = true;
// Try to upgrade old style solver_type enum fields into new string type
if (SolverNeedsTypeUpgrade(*param)) {
LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
<< "'solver_type' field (enum)': " << param_file;
if (!UpgradeSolverType(param)) {
success = false;
LOG(ERROR) << "Warning: had one or more problems upgrading "
<< "SolverType (see above).";
} else {
LOG(INFO) << "Successfully upgraded file specified using deprecated "
<< "'solver_type' field (enum) to 'type' field (string).";
LOG(WARNING) << "Note that future Caffe releases will only support "
<< "'type' field (string) for a solver's type.";
}
}
return success;
}
主要的问题就是在旧版本中Solver的type是enum
类型,而新版本的变为了string
。
2.3 Solver初始化
在上面train
代码函数中使用如下代码进行solver的初始化工作,使用下面的代码用于初始化Solver和Net:
//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式
shared_ptr<caffe::