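1. Overview
During training, the network's weights are periodically saved to files for later use, such as deployment. In Caffe this is done by the Snapshot() function. Caffe supports two formats for weight files, BinaryProto and HDF5; if no format is specified, BinaryProto is the default. This article covers the BinaryProto path; the HDF5 path is structured in a similar way.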
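Overall flow of saving:
When the training iteration count reaches a multiple of the configured snapshot interval, Solver::Snapshot() is called to save the weights. Through SnapshotToBinaryProto() it calls the net's ToProto(), which calls ToProto() on each layer, which in turn calls ToProto() on each of the layer's blobs. This layered chain of calls fills in a single NetParameter message holding every layer's weights, and that message is then written to the file.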
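To make the call order concrete, here is a minimal, self-contained C++ sketch. MockSolver, MockNet, MockLayer and MockBlob are hypothetical stand-ins, not Caffe classes: instead of filling protobuf messages, each ToProto() only records its own name in a trace and then delegates one level down, so the returned trace shows the Solver, Net, Layer, Blob order.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for Caffe's Blob/Layer/Net/Solver.
// Each ToProto() appends a trace entry, then delegates one level down.
struct MockBlob {
  void ToProto(std::vector<std::string>* trace) const {
    trace->push_back("Blob::ToProto");
  }
};

struct MockLayer {
  std::vector<MockBlob> blobs;  // learnable parameters (e.g. weights, bias)
  void ToProto(std::vector<std::string>* trace) const {
    trace->push_back("Layer::ToProto");
    for (const MockBlob& b : blobs) b.ToProto(trace);
  }
};

struct MockNet {
  std::vector<MockLayer> layers;
  void ToProto(std::vector<std::string>* trace) const {
    trace->push_back("Net::ToProto");
    for (const MockLayer& l : layers) l.ToProto(trace);
  }
};

struct MockSolver {
  MockNet net;
  // Returns the sequence of ToProto() calls a snapshot would trigger.
  std::vector<std::string> Snapshot() const {
    std::vector<std::string> trace{"Solver::Snapshot"};
    net.ToProto(&trace);
    return trace;
  }
};
```

For a net with one layer holding two blobs (weights and bias) followed by a parameter-free layer such as ReLU, Snapshot() returns Solver::Snapshot, Net::ToProto, Layer::ToProto, Blob::ToProto, Blob::ToProto, Layer::ToProto: a layer without learnable blobs still gets its own entry, just with nothing under it.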
2. Flow analysis
2.1 Saving the caffemodel file
This file stores the weights the network has learned. The entry point called whenever a snapshot is needed is:
//solver.cpp
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:  // BinaryProto file
      model_filename = SnapshotToBinaryProto();
      break;
    case caffe::SolverParameter_SnapshotFormat_HDF5:  // HDF5 file
      model_filename = SnapshotToHDF5();
      break;
    default:
      LOG(FATAL) << "Unsupported snapshot format.";
  }
  SnapshotSolverState(model_filename);  // save the solverstate
}
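Snapshot() itself does not decide when to run. In upstream Caffe it is called from Solver::Step() after the iteration counter has been incremented, when the solver's `snapshot` field is non-zero and the counter is a multiple of it (an externally requested snapshot, e.g. via a signal, is also possible and is not modeled here). The sketch below reproduces only the periodic condition; ShouldSnapshot and SnapshotIters are hypothetical helpers, not Caffe functions.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper modeled on the periodic check in Solver::Step().
// `iter` is the counter value after the finished iteration was counted;
// `snapshot_interval` plays the role of the solver's `snapshot` field,
// where 0 disables periodic snapshots.
bool ShouldSnapshot(int iter, int snapshot_interval) {
  return snapshot_interval > 0 && iter % snapshot_interval == 0;
}

// Iterations at which periodic snapshots would be written during a run of
// `max_iter` iterations (any extra save at the end of training is ignored).
std::vector<int> SnapshotIters(int max_iter, int snapshot_interval) {
  std::vector<int> iters;
  for (int iter = 1; iter <= max_iter; ++iter) {
    if (ShouldSnapshot(iter, snapshot_interval)) iters.push_back(iter);
  }
  return iters;
}
```

With max_iter = 10000 and snapshot = 4000, periodic snapshots land at iterations 4000 and 8000.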
This function calls SnapshotToBinaryProto(), which in turn uses the net's ToProto() interface:
//solver.cpp
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  string model_filename = SnapshotFilename(".caffemodel");  // path of the output file
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
  NetParameter net_param;  // holds everything that will be saved
  net_->ToProto(&net_param, param_.snapshot_diff());  // calls Net::ToProto() in net.cpp
  WriteProtoToBinaryFile(net_param, model_filename);
  return model_filename;
}
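SnapshotFilename() builds the output path from the solver's `snapshot_prefix`. As far as I can tell from the upstream source, the name has the form `<snapshot_prefix>_iter_<iter><extension>`, and the same function with the extension ".solverstate" names the state file, so both files of one snapshot share a base name. MakeSnapshotFilename below is a hypothetical re-implementation of that scheme using std::to_string, not Caffe's own code.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper reproducing the snapshot naming scheme:
//   <snapshot_prefix>_iter_<iter><extension>
std::string MakeSnapshotFilename(const std::string& snapshot_prefix, int iter,
                                 const std::string& extension) {
  return snapshot_prefix + "_iter_" + std::to_string(iter) + extension;
}
```

For example, a prefix of "examples/mnist/lenet" at iteration 10000 yields "examples/mnist/lenet_iter_10000.caffemodel" and "examples/mnist/lenet_iter_10000.solverstate".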
Inside the net, ToProto() calls each layer's ToProto() in turn; the layer-level implementation is:
//layer.hpp
// Serialize LayerParameter to protocol buffer
template <typename Dtype>
void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
  param->Clear();
  param->CopyFrom(layer_param_);
  param->clear_blobs();
  for (int i = 0; i < blobs_.size(); ++i) {
    blobs_[i]->ToProto(param->add_blobs(), write_diff);
  }
}
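What add_blobs() receives depends on the layer type. For a convolution layer, the learnable blobs are normally the weights, shaped (num_output, channels / group, kernel_h, kernel_w), plus a bias of shape (num_output) when bias_term is enabled; parameter-free layers such as ReLU or pooling contribute no blobs. ConvBlobShapes and TotalCount below are hypothetical helpers that assume this layout to count how many values one layer adds to the caffemodel.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: shapes of the learnable blobs a convolution layer
// would serialize, assuming weights (num_output, channels/group, kh, kw)
// and an optional bias (num_output).
std::vector<std::vector<int64_t>> ConvBlobShapes(int64_t num_output,
                                                 int64_t channels,
                                                 int64_t group,
                                                 int64_t kernel_h,
                                                 int64_t kernel_w,
                                                 bool bias_term) {
  std::vector<std::vector<int64_t>> shapes;
  shapes.push_back({num_output, channels / group, kernel_h, kernel_w});
  if (bias_term) shapes.push_back({num_output});
  return shapes;
}

// Total number of values stored for a list of blob shapes
// (each blob contributes the product of its dimensions).
int64_t TotalCount(const std::vector<std::vector<int64_t>>& shapes) {
  int64_t total = 0;
  for (const auto& shape : shapes) {
    int64_t count = 1;
    for (int64_t d : shape) count *= d;
    total += count;
  }
  return total;
}
```

A convolution with 20 outputs, 1 input channel and a 5x5 kernel stores 20 x 1 x 5 x 5 = 500 weights plus 20 biases, 520 values in total.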
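Finally the call reaches Blob's ToProto(), which performs the actual storing of the weight data. The listing below is the float specialization; the double version follows the same pattern but writes to the double_data and double_diff fields: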
//blob.cpp
template <>
void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const {
  proto->clear_shape();
  for (int i = 0; i < shape_.size(); ++i) {
    proto->mutable_shape()->add_dim(shape_[i]);
  }
  proto->clear_data();
  proto->clear_diff();
  const float* data_vec = cpu_data();
  for (int i = 0; i < count_; ++i) {
    proto->add_data(data_vec[i]);
  }
  if (write_diff) {
    const float* diff_vec = cpu_diff();
    for (int i = 0; i < count_; ++i) {
      proto->add_diff(diff_vec[i]);
    }
  }
}
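Blob::ToProto() writes the shape and then count_ values as one flat repeated field, so the multi-dimensional structure survives only through the shape. The sketch below uses a hypothetical MockBlobProto struct in place of the generated protobuf class: it copies the data in memory order, copies diff only when write_diff is true, and shows the row-major offset formula (the one Blob::offset() uses for 4-D blobs) that maps an (n, c, h, w) index back to a position in the flat array.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the generated BlobProto message.
struct MockBlobProto {
  std::vector<int64_t> shape;  // like BlobShape's repeated dim field
  std::vector<float> data;     // like the repeated data field
  std::vector<float> diff;     // like the repeated diff field
};

// Product of all dimensions, i.e. what Blob::count() returns.
int64_t Count(const std::vector<int64_t>& shape) {
  int64_t count = 1;
  for (int64_t d : shape) count *= d;
  return count;
}

// Mirrors the steps of Blob<float>::ToProto() on plain vectors.
MockBlobProto BlobToProto(const std::vector<int64_t>& shape,
                          const std::vector<float>& data,
                          const std::vector<float>& diff, bool write_diff) {
  MockBlobProto proto;
  proto.shape = shape;
  const std::size_t count = static_cast<std::size_t>(Count(shape));
  for (std::size_t i = 0; i < count; ++i) proto.data.push_back(data[i]);
  if (write_diff) {
    for (std::size_t i = 0; i < count; ++i) proto.diff.push_back(diff[i]);
  }
  return proto;
}

// Row-major position of element (n, c, h, w) in an N x C x H x W blob.
int64_t Offset(const std::vector<int64_t>& shape, int64_t n, int64_t c,
               int64_t h, int64_t w) {
  return ((n * shape[1] + c) * shape[2] + h) * shape[3] + w;
}
```

For a 2 x 3 x 4 x 5 blob, count is 120 and the last element (1, 2, 3, 4) sits at flat position 119.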
2.2 Saving the solverstate file
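This file stores optimizer state from training, chiefly the per-parameter update history (the momentum term built up from past gradients). First, let's look at how that history is updated: when a solver runs Step(), every iteration updates the history kept for each learnable parameter of the network.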
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  ......
  while (iter_ < stop_iter) {
    ......
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();
    ......
    ApplyUpdate();  // compute the update and refresh the gradient history
    ......
  }
}
ApplyUpdate() calls ComputeUpdateValue() for each learnable parameter; this is where the history for the current iteration is computed:
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  Dtype momentum = this->param_.momentum();
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  // i.e. history = local_rate * diff + momentum * history, then diff = history
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
        net_params[param_id]->cpu_diff(), momentum,
        history_[param_id]->mutable_cpu_data());
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    sgd_update_gpu(net_params[param_id]->count(),
        net_params[param_id]->mutable_gpu_diff(),
        history_[param_id]->mutable_gpu_data(),
        momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
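The CPU branch reads more easily as a formula. caffe_cpu_axpby(N, alpha, X, beta, Y) computes Y = alpha * X + beta * Y, so with X being the parameter's diff and Y its history blob, the two calls amount to history = local_rate * diff + momentum * history followed by diff = history; Net::Update() later subtracts diff from the data. The plain-vector sketch below (hypothetical function names, no Caffe types) reproduces that arithmetic; the GPU branch, sgd_update_gpu, is intended to compute the same thing in a single kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the CPU branch of SGDSolver::ComputeUpdateValue() on plain vectors:
//   history = local_rate * diff + momentum * history   (caffe_cpu_axpby)
//   diff    = history                                  (caffe_copy)
void ComputeUpdateValueCpu(std::vector<float>* diff,
                           std::vector<float>* history,
                           float local_rate, float momentum) {
  for (std::size_t i = 0; i < diff->size(); ++i) {
    (*history)[i] = local_rate * (*diff)[i] + momentum * (*history)[i];
    (*diff)[i] = (*history)[i];
  }
}

// What Net::Update() then does for each parameter: data -= diff.
void ApplyDiff(std::vector<float>* data, const std::vector<float>& diff) {
  for (std::size_t i = 0; i < data->size(); ++i) (*data)[i] -= diff[i];
}
```

With local_rate = 0.5, momentum = 0.5, a gradient of 2 and a stored history of 4, the new history (and diff) is 0.5 * 2 + 0.5 * 4 = 3, so a weight of 10 becomes 7.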
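As shown above, the gradient history is kept in an array of blobs, history_, with one blob per learnable parameter. When SnapshotSolverState() is called, this history is written to a file so that training can later be resumed from it:
SnapshotSolverState(model_filename);  // save the solverstate
As with the weights, the state can be saved in either of the two formats. The implementation is: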
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
  switch (this->param_.snapshot_format()) {
    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
      SnapshotSolverStateToBinaryProto(model_filename);
      break;
    case caffe::SolverParameter_SnapshotFormat_HDF5:
      SnapshotSolverStateToHDF5(model_filename);
      break;
    default:
      LOG(FATAL) << "Unsupported snapshot format.";
  }
}

template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
    const string& model_filename) {
  SolverState state;
  state.set_iter(this->iter_);
  state.set_learned_net(model_filename);
  state.set_current_step(this->current_step_);
  state.clear_history();
  for (int i = 0; i < history_.size(); ++i) {
    // Add history
    BlobProto* history_blob = state.add_history();
    history_[i]->ToProto(history_blob);
  }
  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
  LOG(INFO)
    << "Snapshotting solver state to binary proto file " << snapshot_filename;
  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}
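Why store history_ at all? With momentum, each update depends on the previous one, so a run resumed from the weights alone does not continue the same trajectory. Besides history, SolverState records iter and current_step (used by step-based learning-rate policies such as multistep), and learned_net points at the matching .caffemodel. The toy below uses a hypothetical MockSolverState holding just iter and a one-element history; it trains a single scalar weight with a constant gradient and compares an uninterrupted run against resuming with and without the saved history.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, stripped-down analogue of the SolverState message.
struct MockSolverState {
  int iter = 0;                // like SolverState.iter
  std::vector<float> history;  // like SolverState.history (one value here)
};

// Runs `steps` momentum-SGD iterations on one scalar weight whose gradient
// is fixed at `grad`, continuing from `state`; returns the final weight.
float Train(float weight, float grad, float lr, float momentum, int steps,
            MockSolverState* state) {
  if (state->history.empty()) state->history.assign(1, 0.0f);  // fresh start
  for (int s = 0; s < steps; ++s) {
    float& h = state->history[0];
    h = lr * grad + momentum * h;  // same rule as ComputeUpdateValue()
    weight -= h;                   // same as Net::Update(): data -= diff
    ++state->iter;
  }
  return weight;
}
```

With lr = 0.5, momentum = 0.5, gradient 1 and an initial weight of 10, four uninterrupted steps give 6.9375. Restoring the state saved after two steps reproduces 6.9375 exactly, while resuming from the same weight with a zeroed history ends at 7.5.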
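That completes the process of saving weight files. Loading a weight file back into a network was covered in an earlier article; see Section 3 of "Caffe源码,训练流程分析" (Caffe source code: training flow analysis).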