Caffe源码,caffe::Solver<Dtype>::Snapshot运行分析

1. 概述

在进行网络训练的时候会保存网络权值信息到文件中,用来后序部署等。在Caffe中实现这个功能是使用Snapshot()函数实现的,在Caffe中权值文件的保存形式有2中,它们是BinaryProto格式文件和hdf5格式文件,一般未指明的情况下,缺省为BinaryProto格式的形式。本篇文章介绍的存储方式是BinaryProto,hdf5形式的存储方式也是具有类似性。

存储的总体流程:
在网络训练次数达到规定的权值文件保存的次数时,就会调用Solver::Snapshot()函数保存权值文件,这个函数会调用到net的权值文件保存函数ToProto(),net的这个函数接下来就会调用Layer里面的函数ToProto(),最后会调用到Blob里面的ToProto()函数。正是这样层级的调用方式将网络的权值保存到文件中。

2. 流程分析

2.1 caffemodel文件的保存

这个文件保存的是网络中学习到的权值文件,在需要保存权值文件的时候调用的接口是:

//solver.cpp
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: //BinaryProto文件
    model_filename = SnapshotToBinaryProto();
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5: //hdf5文件
    model_filename = SnapshotToHDF5();
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }

  SnapshotSolverState(model_filename); //保存solver_state
}

这个函数会调用SnapshotToBinaryProto()来使用net的ToProto()接口:

//net.cpp
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  string model_filename = SnapshotFilename(".caffemodel");  //得到保存文件的路径
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
  NetParameter net_param;	//存放需要保存的所有信息
  net_->ToProto(&net_param, param_.snapshot_diff());	//调用net.cpp中的ToProto()函数
  WriteProtoToBinaryFile(net_param, model_filename);
  return model_filename;
}

在net的内部ToProto()是逐层调用Layer的ToProto()接口的:

//layer.hpp
// Serialize LayerParameter to protocol buffer
template <typename Dtype>
void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
  param->Clear();
  param->CopyFrom(layer_param_);
  param->clear_blobs();
  for (int i = 0; i < blobs_.size(); ++i) {
    blobs_[i]->ToProto(param->add_blobs(), write_diff);
  }
}

最后就调用到了Blob的ToProto()接口,实现具体权值数据的存储:

//blob.cpp
template <>
void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const {
  proto->clear_shape();
  for (int i = 0; i < shape_.size(); ++i) {
    proto->mutable_shape()->add_dim(shape_[i]);
  }
  proto->clear_data();
  proto->clear_diff();
  const float* data_vec = cpu_data();
  for (int i = 0; i < count_; ++i) {
    proto->add_data(data_vec[i]);
  }
  if (write_diff) {
    const float* diff_vec = cpu_diff();
    for (int i = 0; i < count_; ++i) {
      proto->add_diff(diff_vec[i]);
    }
  }
}

2.2 solverstate文件的保存

这个文件保存的的是在训练过程中的梯度信息,首先来看下梯度信息的更新过程吧,在某个具体的solver在进行Step()函数的时候,没进行一次解算都会保存一次网络中的梯度信息。

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
    ......

  while (iter_ < stop_iter) {
    ......
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();
    ......
    ApplyUpdate(); //保存梯度信息
    ......
  }
}

ApplyUpdate()函数中会调用ComputeUpdateValue()函数用来保存当前迭代次数的梯度信息

template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  Dtype momentum = this->param_.momentum();
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  //保存结果算过程中的梯度信息
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    sgd_update_gpu(net_params[param_id]->count(),
        net_params[param_id]->mutable_gpu_diff(),
        history_[param_id]->mutable_gpu_data(),
        momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

上面讲到了网络的梯度信息是保存在一个blob数组中的,那么在调用SnapshotSolverState()函数的时候就会将梯度信息保存到文件中,以供后面网络进行恢复。

SnapshotSolverState(model_filename); //保存solver_state

这里保存的形式也是一样分为两种类型的情况,其实现为:

template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
  switch (this->param_.snapshot_format()) {
    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
      SnapshotSolverStateToBinaryProto(model_filename);
      break;
    case caffe::SolverParameter_SnapshotFormat_HDF5:
      SnapshotSolverStateToHDF5(model_filename);
      break;
    default:
      LOG(FATAL) << "Unsupported snapshot format.";
  }
}

template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
    const string& model_filename) {
  SolverState state;
  state.set_iter(this->iter_);
  state.set_learned_net(model_filename);
  state.set_current_step(this->current_step_);
  state.clear_history();
  for (int i = 0; i < history_.size(); ++i) {
    // Add history
    BlobProto* history_blob = state.add_history();
    history_[i]->ToProto(history_blob);
  }
  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
  LOG(INFO)
    << "Snapshotting solver state to binary proto file " << snapshot_filename;
  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}

在上面讲完了有关权值文件保存的过程,对于权值文件加载到网络中的流程在之前的文章中有所介绍,请参考:Caffe源码,训练流程分析中的第三节。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值