Caffe Source Code Study: The solver Optimization Methods

I. Introduction
The solver is the optimization method used to minimize the loss function. Caffe provides six different optimization methods:

(1)SGD;
(2)AdaGrad;
(3)AdaDelta;
(4)Adam;
(5)RMSProp;
(6)Nesterov;

The objective function to be optimized is the average loss over all the data in the dataset, plus a regularization term if one is used:

$$L(W) = \frac{1}{N}\sum_{i=1}^{N} f_W\big(x^{(i)}\big) + \lambda r(W)$$

N is the mini-batch size; the whole dataset is split into a number of mini-batches.

(1)SGD:

Stochastic gradient descent is defined in contrast to batch gradient descent, which acts on the entire dataset. The latter requires computing the error of every sample in a single forward/backward pass before updating the weights; this guarantees the gradient moves in the most accurate direction, but it is infeasible for large datasets.

The earliest stochastic gradient descent, at the other extreme, updates the weights after every single sample. The descent direction given by a single sample is unstable and fluctuates heavily, making it hard to drive the result toward a local optimum.

Mini-batch SGD arose as a compromise between the two: each update is computed over a small batch of data. Compared with the former its convergence is faster; compared with the latter its parameter updates are more stable.

Plain mini-batch SGD:

$$W_{t+1} = W_t - \alpha \frac{\partial L(W)}{\partial W}$$

The sign in front of α determines whether this is gradient ascent or gradient descent.

Later, *Learning representations by back-propagating errors* introduced the concept of momentum:

$$V_{t+1} = \beta V_t - \alpha \frac{\partial L(W)}{\partial W}$$
$$W_{t+1} = W_t + V_{t+1}$$

In the source code, the update term $V_{t+1}$ is computed as follows:

template <typename Dtype>
__global__ void SGDUpdate(int N, Dtype* g, Dtype* h,
    Dtype momentum, Dtype local_rate) {
  CUDA_KERNEL_LOOP(i, N) {
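    // h[i] holds the momentum history and g[i] the current gradient; the
    // blended value is written back to both, so g becomes the update to apply.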
    g[i] = h[i] = momentum*h[i] + local_rate*g[i];
  }
}

β, i.e. momentum, is usually set to 0.9.

Main effect:

The momentum term adds in the previous momentum: if two consecutive updates point in the same direction, the descent accelerates; if they point in opposite directions, it slows down.

My understanding:
In the early phase of descent, the accumulated momentum is very likely aligned with the current direction, so descent accelerates;

As the parameters approach a local minimum and overshoot the valley floor, consecutive updates point in opposite directions, so the step shrinks and the oscillation around the bottom is damped.

Later, when the gradient itself approaches zero, the momentum may still point in a consistent direction, which can carry the parameters out of that valley toward another minimum.

Because every update takes the previous state into account, both the stability and the speed of gradient descent are improved.

(Aside: Newton-type methods compute the gradient of the gradient, which acts like an acceleration, so they can reach a local optimum along the descent direction very quickly; but dealing with the Hessian matrix costs on the order of O(n^3), which makes them hard to apply in practice.)
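
To make the two update formulas concrete, here is a minimal CPU-side sketch of momentum SGD on a flat parameter vector (plain C++, independent of Caffe; the function and variable names are illustrative):

#include <vector>

// One momentum-SGD step: history = momentum*history + lr*grad, param -= history.
// This mirrors the SGDUpdate kernel above, with the sign folded into the final subtraction.
void momentum_sgd_step(std::vector<float>& params,
                       const std::vector<float>& grads,
                       std::vector<float>& history,
                       float lr, float momentum) {
  for (size_t i = 0; i < params.size(); ++i) {
    history[i] = momentum * history[i] + lr * grads[i];
    params[i] -= history[i];
  }
}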

(2)AdaGrad

Adaptive Subgradient Methods for Online Learning and Stochastic Optimization

Basic idea: use a separate learning rate for every parameter $W_i$ of the model. How does the learning rate keep adjusting during training? By normalizing each parameter's step with its accumulated squared gradients:

$$W_{t+1} = W_t - \frac{\eta}{\left(\sum_{\tau=1}^{t} g_\tau^2\right)^{0.5} + \epsilon}\, g_t$$

  CUDA_KERNEL_LOOP(i, N) {
    float gi = g[i];
    float hi = h[i] = h[i] + gi*gi;  // accumulate the sum of squared gradients
    g[i] = local_rate * gi / (sqrt(hi) + delta);  // rescale the gradient by the accumulator
  }

When the accumulated gradients are small, the denominator is small and the gradient is amplified;
when the accumulated gradients grow, the denominator grows and the gradient is shrunk.

Because the method adapts the gradient directly, it copes well with vanishing and exploding gradients. However, if the initial weights are too large the steps are damped from the start, and since the sum of squared gradients only ever grows, the effective step quickly approaches zero and training ends too early.
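
A quick worked example of the scaling: with base learning rate η = 0.01 and delta ≈ 0, a parameter whose accumulated squared gradient has reached 100 has its gradient scaled by 0.01/√100 = 0.001, and once the accumulator reaches 10000 the scale drops to 0.0001; since the accumulator never shrinks, this is exactly the early-stalling behaviour described above.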

(3)AdaDelta

ADADELTA: AN ADAPTIVE LEARNING RATE METHOD

The paper points out that the most effective form of descent is really Newton's method, but Newton's method requires the second-order Hessian matrix:

$$\Delta x_t = -H_t^{-1} g_t$$

Since that is expensive to compute, AdaDelta approximates the second-order step with first-order quantities, using the diagonal of the Hessian in place of its inverse:

$$\Delta x_t = -\frac{1}{\mathrm{diag}(H_t) + \mu}\, g_t$$

To go further, the method builds on AdaGrad with the idea of a window, and replaces the sum of squares with a mean.

Window: since AdaGrad's accumulated penalty eventually drives the step to zero, AdaDelta only accumulates over the w states preceding step t; this avoids over-penalizing and keeps the step from collapsing to zero.

Mean: the sum of squares is replaced by the mean over those w states.

$$W_{t+1} = W_t - \frac{1}{\mathrm{diag}(H_t)}\,\frac{E[g_{t-w:t}]^2}{E[g^2_{t-w:t}]}\, g_t$$

Computing $E[g_{t-w:t}]$ would require storing the previous w states, so in practice an exponential moving average is used instead:

$$E[g^2]_t = \rho E[g^2]_{t-1} + (1-\rho)\, g_t^2$$

  CUDA_KERNEL_LOOP(i, N) {
    float gi = g[i];
    // h keeps the moving average of squared gradients E[g^2]
    float hi = h[i] = momentum * h[i] + (1-momentum) * gi * gi;
    // h2 keeps the moving average of squared updates E[dx^2]; it rescales the gradient
    gi = gi * sqrt((h2[i] + delta) / (hi + delta));
    h2[i] = momentum * h2[i] + (1-momentum) * gi * gi;
    g[i] = local_rate * gi;
  }

(4)RMSProp

This method builds on (3): when rescaling the gradient, the square root of the running average is used,

$$g_t \leftarrow \frac{g_t}{\sqrt{E[g^2]_t} + \xi}$$

The full algorithm (the ADADELTA algorithm, as given in that paper) is:
Algorithm: ADADELTA
Require: decay rate ρ, constant ε
Require: initial parameter x_1
1: Initialize accumulation variables E[g²]_0 = E[Δx²]_0 = 0
2: For t = 1 : T do   (loop over all updates)
3:   Compute gradient: g_t
4:   Accumulate gradient: E[g²]_t = ρ E[g²]_{t-1} + (1-ρ) g_t²
5:   Compute update: Δx_t = - (RMS[Δx]_{t-1} / RMS[g]_t) g_t
6:   Accumulate updates: E[Δx²]_t = ρ E[Δx²]_{t-1} + (1-ρ) Δx_t²
7:   Apply update: x_{t+1} = x_t + Δx_t
8: End For

  CUDA_KERNEL_LOOP(i, N) {
    float gi = g[i];
    float hi = h[i] = rms_decay*h[i] + (1-rms_decay)*gi*gi;  // moving average of squared gradients
    g[i] = local_rate * g[i] / (sqrt(hi) + delta);
  }

Drawbacks:

Both methods above inherit AdaGrad's strength of descending quickly in the early phase, but later the gradient changes become negligible, so the result keeps oscillating around a local minimum and does not easily escape that region.

In my own experiments, switching to SGD after training with this method still raised the accuracy by about 3%.

Moreover, ξ is not easy to tune, and numerical blow-up can easily occur.

(5)Adam
ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION

The method computes a per-weight learning rate from estimates of the first and second moments of the gradient, and it has low memory requirements.
First moment:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t$$

Second moment:

$$v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$

Because both moment estimates are initialized at zero and therefore biased toward zero, they are further bias-corrected:

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}$$
$$\hat{v}_t = \frac{v_t}{1-\beta_2^t}$$

The final update is:

$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}\,\hat{m}_t$$
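
There is no kernel excerpt for Adam in this section, so here is a sketch of the per-element update written in the same style as the kernels above (an illustrative sketch, not a verbatim copy of Caffe's adam_solver.cu; it assumes m and v are the history blobs for the two moments and that the bias correction has already been folded into corrected_local_rate):

template <typename Dtype>
__global__ void AdamUpdateSketch(int N, Dtype* g, Dtype* m, Dtype* v,
    Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
  CUDA_KERNEL_LOOP(i, N) {
    float gi = g[i];
    float mi = m[i] = m[i]*beta1 + gi*(1-beta1);     // first moment estimate
    float vi = v[i] = v[i]*beta2 + gi*gi*(1-beta2);  // second moment estimate
    g[i] = corrected_local_rate * mi / (sqrt(vi) + eps_hat);  // final update value
  }
}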

Adam combines the strengths of RMSProp and AdaDelta and is currently the most widely used optimizer. Plain SGD takes too long to converge and has no resistance at all to local minima, so the adaptive-learning-rate algorithms are the better choice.

II. solver.hpp
1. Related parameters

message SolverParameter {


  // Proto filename of the train net
  optional string net = 24;
  // Inline net parameters; one train net may be paired with one or more test nets
  optional NetParameter net_param = 25;

  optional string train_net = 1; // proto filename of the train net
  repeated string test_net = 2; // proto filenames of the test nets
  optional NetParameter train_net_param = 21; // Inline train net params.
  repeated NetParameter test_net_param = 22; // Inline test net params.
/* Exactly one of the groups above must be used to specify the train net */
  // NetState used when restoring
  optional NetState train_state = 26;
  repeated NetState test_state = 27;

  // Number of iterations per test phase
  repeated int32 test_iter = 3;

  // Run a test phase every test_interval training iterations,
  // not after every single training iteration
  optional int32 test_interval = 4 [default = 0];
  optional bool test_compute_loss = 19 [default = false];
  // Whether to run an initial test pass before the first iteration,
  // to make sure memory is available and results start out consistent
  optional bool test_initialization = 32 [default = true];
  // The per-layer learning rate multipliers are multiplied by base_lr
  // to give the effective learning rate
  optional float base_lr = 5;
  // Display results every `display` iterations
  optional int32 display = 6;
  // Display the loss averaged over the last average_loss iterations
  optional int32 average_loss = 33 [default = 1];
  optional int32 max_iter = 7; // maximum number of iterations
  // accumulate gradients over `iter_size` x `batch_size` instances
  optional int32 iter_size = 36 [default = 1];

  // The learning rate decay policy:
  //    - fixed: keep base_lr unchanged.
  //    - step:   base_lr * gamma ^ (floor(iter / step))
  //    - exp:    base_lr * gamma ^ iter
  //    - inv:    base_lr * (1 + gamma * iter) ^ (- power)
  //    - multistep: like step, but with irregular, user-defined step points
  //    - poly: the effective learning rate follows a polynomial decay, to be
  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
  //    - sigmoid: the effective learning rate follows a sigmod decay
  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
  optional string lr_policy = 8;
  optional float gamma = 9; 
  optional float power = 10; 
  optional float momentum = 11; // the momentum
  optional float weight_decay = 12; // the weight decay
  // The regularization type: "L1" or "L2"
  optional string regularization_type = 29 [default = "L2"];
  // the stepsize for learning rate policy "step"
  optional int32 stepsize = 13;
  // the step values for learning rate policy "multistep" (user-defined)
  repeated int32 stepvalue = 34;

  // Gradient clipping threshold, to guard against exploding gradients
  optional float clip_gradients = 35 [default = -1];

  optional int32 snapshot = 14 [default = 0]; // The snapshot interval
  optional string snapshot_prefix = 15; // The prefix for the snapshot.
  // Whether to also snapshot the diffs; if true, this puts heavy pressure on memory and storage
  optional bool snapshot_diff = 16 [default = false];
  enum SnapshotFormat {
    HDF5 = 0;
    BINARYPROTO = 1;
  }
  optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
  // The solver mode: CPU (0) or GPU (1); GPU is the default
  enum SolverMode {
    CPU = 0;
    GPU = 1;
  }
  optional SolverMode solver_mode = 17 [default = GPU];
  // The device ID; defaults to 0
  optional int32 device_id = 18 [default = 0];
  // If non-negative, this value seeds the random number generator; otherwise the seed is taken from the system clock
  optional int64 random_seed = 20 [default = -1];

  // The type of the optimization method
  optional string type = 40 [default = "SGD"];

  // Numerical stability term for RMSProp, AdaGrad, AdaDelta and Adam; keeps the denominator away from 0.
  optional float delta = 31 [default = 1e-8];
  // Decay rate of the second moment estimate (beta2) in Adam
  optional float momentum2 = 39 [default = 0.999];

  // RMSProp decay value
  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
  optional float rms_decay = 38 [default = 0.99];

  // If true, print information about the state of the net that may help with
  // debugging learning problems.
  optional bool debug_info = 23 [default = false];

  // If false, don't save a snapshot after training finishes.
  optional bool snapshot_after_train = 28 [default = true];

  // The old way of specifying the solver type; a plain string is used now
  enum SolverType {
    SGD = 0;
    NESTEROV = 1;
    ADAGRAD = 2;
    RMSPROP = 3;
    ADADELTA = 4;
    ADAM = 5;
  }
  // DEPRECATED: use type instead of solver_type
  optional SolverType solver_type = 30 [default = SGD];

  // Overlap compute and communication for data parallel training
  optional bool layer_wise_reduce = 41 [default = true];
}
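
To tie these fields together, a typical solver.prototxt looks roughly like the following (the values are illustrative, modeled on the MNIST LeNet example shipped with Caffe):

net: "examples/mnist/lenet_train_test.prototxt"
test_iter: 100
test_interval: 500
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
lr_policy: "inv"
gamma: 0.0001
power: 0.75
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
solver_mode: GPU
type: "SGD"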

2. Member variables


  SolverParameter param_;  // solver parameters
  int iter_;  // iteration counter
  int current_step_;
  shared_ptr<Net<Dtype> > net_;  // pointer to the train net
  vector<shared_ptr<Net<Dtype> > > test_nets_;  // test nets; there may be more than one
  vector<Callback*> callbacks_;  // callbacks invoked at specific points of an iteration
  vector<Dtype> losses_;  // recent losses, used for smoothing
  Dtype smoothed_loss_;  // the smoothed loss

  // The root solver in data-parallel training; it holds the shared layers
  const Solver* const root_solver_;

  // Function that returns the requested action: stop or snapshot
  ActionCallback action_request_function_;

  // Early stopping is itself a form of optimization: when the loss no longer drops noticeably, training is stopped in time.
  bool requested_early_exit_;

  DISABLE_COPY_AND_ASSIGN(Solver);

3. Header file

#ifndef CAFFE_SOLVER_HPP_
#define CAFFE_SOLVER_HPP_
#include <boost/function.hpp>
#include <string>
#include <vector>

#include "caffe/net.hpp"
#include "caffe/solver_factory.hpp"

namespace caffe {


  /* Stop training early, or take a snapshot without stopping training. Ctrl-C stops training and saves a snapshot at the same time. */
  namespace SolverAction {
    enum Enum {
      NONE = 0,  // Take no special action.
      STOP = 1,  // stop training
      SNAPSHOT = 2  // take a snapshot so training can later resume from it
    };
  }

/**
 * @brief Type of a function that returns the requested action, stop or snapshot, as an enum
 */
typedef boost::function<SolverAction::Enum()> ActionCallback;

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Subclasses must implement ApplyUpdate to apply the parameter update.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();

  // The client calls this function to request early stop or a snapshot
  void SetActionFunction(ActionCallback func);
  SolverAction::Enum GetRequestedAction();
  // The main entry point of the solver.
  // Passing in a non-zero iters allows training to continue on top of a pre-trained model.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
  // Resume training from a previously saved snapshot
  void Restore(const char* resume_file);
  // Solver::Snapshot implements the basic snapshotting utility.
  // Subclasses implement SnapshotSolverState() to produce a protobuf of the
  // learned net and write it to disk.

  void Snapshot();
  virtual ~Solver() {}
  inline const SolverParameter& param() const { return param_; }
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }
  int iter() { return iter_; }

  // Invoked at specific points during an iteration
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  void add_callback(Callback* value) {
    callbacks_.push_back(value);
  }

  void CheckSnapshotWritePermissions();
  /**
   * @brief Returns the solver type.
   */
  virtual inline const char* type() const { return ""; }

 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;
  string SnapshotFilename(const string extension);
  string SnapshotToBinaryProto();
  string SnapshotToHDF5();
  // The test routine
  void TestAll();
  void Test(const int test_net_id = 0);
  virtual void SnapshotSolverState(const string& model_filename) = 0;
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  void DisplayOutputBlobs(const int net_id);
  void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);

};

/**
 * @brief Solver that only computes gradients, used as worker
 *        for multi-GPU training.
 */
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> {
 public:
  explicit WorkerSolver(const SolverParameter& param,
      const Solver<Dtype>* root_solver = NULL)
      : Solver<Dtype>(param, root_solver) {}

 protected:
  void ApplyUpdate() {}
  void SnapshotSolverState(const string& model_filename) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromBinaryProto(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromHDF5(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
};

}  // namespace caffe

#endif  // CAFFE_SOLVER_HPP_

4. Implementation

namespace caffe {

template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;  // store the callback that returns the requested action enum
}

template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  if (action_request_function_) {

    return action_request_function_();  // returns the requested action: STOP or SNAPSHOT
  }
  return SolverAction::NONE;  // otherwise request nothing
}

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  Init(param);  // the constructor simply delegates to Init()
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  SolverParameter param;
  ReadSolverParamsFromTextFileOrDie(param_file, &param);  // read the parameters from a text file into a protobuf
  Init(param);
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
    << std::endl << param.DebugString();
  param_ = param;
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
  CheckSnapshotWritePermissions();
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());  // set the random seed
  }
  // Scaffolding code
  InitTrainNet();  // initialize the train net
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();  // count the ways the train net was specified
  const string& field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "  // the train net must be given in exactly one of these ways
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;
  NetParameter net_param;  // the net parameters
  // The branches below convert whichever form was specified into a NetParameter
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  }
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }
  // Set the correct NetState.  We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  // precedence).
  NetState net_state;
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);
  if (Caffe::root_solver()) {
    net_.reset(new Net<Dtype>(net_param));
  } else {
    net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  const int num_test_nets = num_test_net_params + num_test_net_files;
  if (num_generic_nets) {
      CHECK_GE(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  } else {
      CHECK_EQ(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  }
  int test_net_id = 0;
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net_param";
      net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net file: " + param_.test_net(i);
      ReadNetParamsFromTextFileOrDie(param_.test_net(i),
          &net_params[test_net_id]);
  }
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  if (has_net_param) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());
    }
  }
  if (has_net_file) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState.  We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    NetState net_state;
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    }
    net_params[i].mutable_state()->CopyFrom(net_state);
    LOG(INFO)
        << "Creating test net (#" << i << ") specified by " << sources[i];
    if (Caffe::root_solver()) {
      test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    } else {
      test_nets_[i].reset(new Net<Dtype>(net_params[i],
          root_solver_->test_nets_[i].get()));
    }
    test_nets_[i]->set_debug_info(param_.debug_info());
  }
}

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;  // starting iteration; non-zero when resuming from a snapshot
  const int stop_iter = iter_ + iters;  // stop after iters more iterations
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;

  while (iter_ < stop_iter) {
    // zero out the parameter diffs
    net_->ClearParamDiffs();
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate loss and gradients over iter_size forward/backward passes
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();  // average the loss over the accumulation steps
    // update the moving average of the loss
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << ", loss = " << smoothed_loss_;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();  // implemented by the subclass; performs the parameter update

    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;

    SolverAction::Enum request = GetRequestedAction();

    // snapshot and/or early stop if requested
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;

      break;
    }
  }
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  Step(param_.max_iter() - iter_);
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively).  Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);

    UpdateSmoothedLoss(loss, start_iter, average_loss);

    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}

template <typename Dtype>
void Solver<Dtype>::TestAll() {
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;
       ++test_net_id) {
    Test(test_net_id);
  }
}

template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {  // for each test iteration
    SolverAction::Enum request = GetRequestedAction();
    // first check whether an interrupt action has been requested
    while (request != SolverAction::NONE) {
        if (SolverAction::SNAPSHOT == request) {
          Snapshot();
        } else if (SolverAction::STOP == request) {
          requested_early_exit_ = true;
        }
        request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);  // forward pass; the loss is stored in iter_loss
    if (param_.test_compute_loss()) {
      loss += iter_loss;  // accumulate the loss
    }
    if (i == 0) {  // on the first test iteration
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();  // take each output blob
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);  // flatten every value of the blob into test_score
          test_score_output_id.push_back(j);  // and record which output blob it came from
        }
      }
    } else {
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];  // after the first iteration, keep accumulating
        }
      }
    }
  }
  if (requested_early_exit_) {  // early stop requested?
    LOG(INFO) << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {  // whether to report the test loss
    loss /= param_.test_iter(test_net_id);  // average the accumulated loss
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < test_score.size(); ++i) {  // display the outputs
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}

template <typename Dtype>
void Solver<Dtype>::Snapshot() {  // dump the current net state to a file so training can resume from it later
  CHECK(Caffe::root_solver());
  string model_filename;
  switch (param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
    model_filename = SnapshotToBinaryProto();
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5:
    model_filename = SnapshotToHDF5();
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }

  SnapshotSolverState(model_filename);
}

template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (Caffe::root_solver() && param_.snapshot()) {
    CHECK(param_.has_snapshot_prefix())
        << "In solver params, snapshot is specified but snapshot_prefix is not";
    string probe_filename = SnapshotFilename(".tempfile");
    std::ofstream probe_ofs(probe_filename.c_str());
    if (probe_ofs.good()) {
      probe_ofs.close();
      std::remove(probe_filename.c_str());
    } else {
      LOG(FATAL) << "Cannot write to snapshot prefix '"
          << param_.snapshot_prefix() << "'.  Make sure "
          << "that the directory exists and is writeable.";
    }
  }
}

template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
    + extension;
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  string model_filename = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
  NetParameter net_param;
  net_->ToProto(&net_param, param_.snapshot_diff());
  WriteProtoToBinaryFile(net_param, model_filename);
  return model_filename;
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  string model_filename = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
  net_->ToHDF5(model_filename, param_.snapshot_diff());
  return model_filename;
}

template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  CHECK(Caffe::root_solver());  // restore the solver state from a file
  string state_filename(state_file);
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
    RestoreSolverStateFromHDF5(state_filename);
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);
  }
}

template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
    int average_loss) {
  if (losses_.size() < average_loss) {  // fewer losses recorded than the configured window size
    losses_.push_back(loss);
    int size = losses_.size();
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;  // running mean over the losses so far
  } else {  // window is full: replace the oldest loss
    int idx = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[idx]) / average_loss;
    losses_[idx] = loss;
  }
}
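
A brief worked example of the smoothing: with average_loss = 3 and per-iteration losses 2.0, 1.0, 0.6, 0.4, the smoothed loss is 2.0, then (2.0+1.0)/2 = 1.5, then (2.0+1.0+0.6)/3 = 1.2; at the fourth iteration the window is full, the oldest value 2.0 is replaced, and smoothed_loss_ becomes 1.2 + (0.4 - 2.0)/3 ≈ 0.67, i.e. the mean of the last three losses.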

INSTANTIATE_CLASS(Solver);

}  // namespace caffe

III. sgd_solver

sgd_solver implements SGD with momentum; let's look at the header first.

SGDSolver inherits from Solver;
NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver and AdamSolver all inherit from SGDSolver.

Their constructors first construct the base class and then call the corresponding Pre*Solve function (PreSolve, AdaDeltaPreSolve, AdamPreSolve) to finish constructing the derived class.

namespace caffe {


template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }
  virtual inline const char* type() const { return "SGD"; }

  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

 protected:
  void PreSolve();
  Dtype GetLearningRate();
  virtual void ApplyUpdate();
  virtual void Normalize(int param_id);
  virtual void Regularize(int param_id);
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  virtual void ClipGradients();
  virtual void SnapshotSolverState(const string& model_filename);
  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
  virtual void RestoreSolverStateFromHDF5(const string& state_file);
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
  // history_ stores the momentum/update history from previous iterations.
  // update_ stores the update values for the current iteration.
  // temp_ stores any other data needed to compute the update.
  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;

  DISABLE_COPY_AND_ASSIGN(SGDSolver);
};

template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
 public:
  explicit NesterovSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {}
  explicit NesterovSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) {}
  virtual inline const char* type() const { return "Nesterov"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};

template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaGradSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit AdaGradSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
  virtual inline const char* type() const { return "AdaGrad"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with AdaGrad.";
  }

  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};


template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
 public:
  explicit RMSPropSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit RMSPropSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
  virtual inline const char* type() const { return "RMSProp"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with RMSProp.";
    CHECK_GE(this->param_.rms_decay(), 0)
        << "rms_decay should lie between 0 and 1.";
    CHECK_LT(this->param_.rms_decay(), 1)
        << "rms_decay should lie between 0 and 1.";
  }

  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaDeltaSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
  explicit AdaDeltaSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
  virtual inline const char* type() const { return "AdaDelta"; }

 protected:
  void AdaDeltaPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};

/**
 * @brief AdamSolver, an algorithm for first-order gradient-based optimization
 *        of stochastic objective functions, based on adaptive estimates of
 *        lower-order moments. Described in [1].
 *
 * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
 *     arXiv preprint arXiv:1412.6980v8 (2014).
 */
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
 public:
  explicit AdamSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdamPreSolve();}
  explicit AdamSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
  virtual inline const char* type() const { return "Adam"; }

 protected:
  void AdamPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

SGD implementation:
1. PreSolve constructs the class:

template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  // the net's learnable parameters
  history_.clear();
  update_.clear();
  temp_.clear();
  for (int i = 0; i < net_params.size(); ++i) {
    const vector<int>& shape = net_params[i]->shape();
    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));  // history_ blob shaped like the parameter
    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));  // update_ blob shaped like the parameter
    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));  // temp_ blob shaped like the parameter
  }
}

2. Learning rate policies:

template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
  Dtype rate;
  const string& lr_policy = this->param_.lr_policy();  // the learning rate policy configured in the prototxt
  if (lr_policy == "fixed") {  // constant learning rate
    rate = this->param_.base_lr();
  } else if (lr_policy == "step") {  // drop the rate every stepsize iterations
    this->current_step_ = this->iter_ / this->param_.stepsize();  // base_lr * gamma ^ (floor(iter / step))
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "exp") {//base_lr * gamma ^ iter
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
  } else if (lr_policy == "inv") {//base_lr * (1 + gamma * iter) ^ (- power)
    rate = this->param_.base_lr() *
        pow(Dtype(1) + this->param_.gamma() * this->iter_,
            - this->param_.power());
  } else if (lr_policy == "multistep") {//跟step一样,只是步长可以自定义
    if (this->current_step_ < this->param_.stepvalue_size() &&
          this->iter_ >= this->param_.stepvalue(this->current_step_)) {
      this->current_step_++;
      LOG(INFO) << "MultiStep Status: Iteration " <<
      this->iter_ << ", step = " << this->current_step_;
    }
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "poly") {//base_lr (1 - iter/max_iter) ^ (power)
    rate = this->param_.base_lr() * pow(Dtype(1.) -
        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
        this->param_.power());
  } else if (lr_policy == "sigmoid") {// base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
    rate = this->param_.base_lr() * (Dtype(1.) /
        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
          Dtype(this->param_.stepsize())))));
  } else {
    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
  }
  return rate;
}
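
A quick worked example of the "step" policy: with base_lr = 0.01, gamma = 0.1 and stepsize = 10000, the rate is 0.01 for iterations 0-9999, 0.01 × 0.1 = 0.001 for iterations 10000-19999, 0.01 × 0.1² = 0.0001 for iterations 20000-29999, and so on.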

3. Gradient clipping
Before each update, the L2 norm of all parameter gradients is compared against the threshold; if it is larger, the gradients are scaled down so that the norm equals the threshold, which helps avoid exploding gradients.

template <typename Dtype>
void SGDSolver<Dtype>::ClipGradients() {
  const Dtype clip_gradients = this->param_.clip_gradients();
  if (clip_gradients < 0) { return; }
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  Dtype sumsq_diff = 0;
  for (int i = 0; i < net_params.size(); ++i) {
    sumsq_diff += net_params[i]->sumsq_diff();  // sum of squares of all weight gradients
  }
  const Dtype l2norm_diff = std::sqrt(sumsq_diff);
  if (l2norm_diff > clip_gradients) {
    // If the gradient L2 norm exceeds clip_gradients, compute the scale factor
    // scale_factor = clip_gradients / l2norm_diff, which lies in (0, 1);
    // the larger the gradient norm, the smaller the scale factor.
    Dtype scale_factor = clip_gradients / l2norm_diff;
    LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
        << l2norm_diff << " > " << clip_gradients << ") "
        << "by scale factor " << scale_factor;
    for (int i = 0; i < net_params.size(); ++i) {
      net_params[i]->scale_diff(scale_factor);
    }
  }
}
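
For example, with clip_gradients = 10 and a gradient L2 norm of 40, scale_factor = 10/40 = 0.25, so every parameter diff is scaled by 0.25 and the clipped norm is exactly 10; if the norm is already below the threshold, nothing is changed.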

4. ApplyUpdate()

template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  CHECK(Caffe::root_solver());
  Dtype rate = GetLearningRate();  // learning rate from the configured policy
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
  }
  ClipGradients();  // gradient clipping
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    Normalize(param_id);  // normalize for gradient accumulation (iter_size)
    Regularize(param_id);  // add the regularization term to the diff
    ComputeUpdateValue(param_id, rate);  // compute the actual update value
  }
  this->net_->Update();  // apply the updates to the net parameters
}

5. Normalize

template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
  if (this->param_.iter_size() == 1) { return; }
  // Scale gradient to counterbalance accumulation.
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_cpu_diff());
    break;  // i.e. diff *= 1 / iter_size
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_gpu_diff());
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

6. Regularize

$$L(W) = \frac{1}{N}\sum_{i=1}^{N} f_W\big(x^{(i)}\big) + \lambda r(W)$$

L1 regularization:

$$L(W) = \frac{1}{N}\sum_{i=1}^{N} f_W\big(x^{(i)}\big) + \frac{\lambda}{n}\sum_{w}|w|$$

Derivative: $$\frac{\partial L(W)}{\partial w} = \frac{1}{N}\sum_{i=1}^{N}\frac{\partial f_W(x^{(i)})}{\partial w} + \frac{\lambda}{n}\,\mathrm{sgn}(w)$$

L2 regularization:

$$L(W) = \frac{1}{N}\sum_{i=1}^{N} f_W\big(x^{(i)}\big) + \frac{\lambda}{2n}\sum_{w}w^2$$

Derivative: $$\frac{\partial L(W)}{\partial w} = \frac{1}{N}\sum_{i=1}^{N}\frac{\partial f_W(x^{(i)})}{\partial w} + \frac{\lambda}{n}\,w$$
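
In the code below, these derivative terms map directly onto the BLAS calls: for L2, caffe_axpy adds local_decay * w to the parameter's diff; for L1, caffe_cpu_sign (or caffe_gpu_sign) first writes sgn(w) into temp_, and caffe_axpy then adds local_decay * sgn(w) to the diff.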

template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();  // per-parameter weight decay multipliers
  Dtype weight_decay = this->param_.weight_decay();
  string regularization_type = this->param_.regularization_type();  // regularization type: "L1" or "L2"
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    if (local_decay) {
      if (regularization_type == "L2") {
        // add weight decay
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else if (regularization_type == "L1") {
        caffe_cpu_sign(net_params[param_id]->count(),
            net_params[param_id]->cpu_data(),
            temp_[param_id]->mutable_cpu_data());
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    if (local_decay) {
      if (regularization_type == "L2") {
        // add weight decay
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());
      } else if (regularization_type == "L1") {
        caffe_gpu_sign(net_params[param_id]->count(),
            net_params[param_id]->gpu_data(),
            temp_[param_id]->mutable_gpu_data());
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

7. ComputeUpdateValue
The GPU path calls the SGDUpdate kernel shown earlier (via sgd_update_gpu):

template <typename Dtype>
__global__ void SGDUpdate(int N, Dtype* g, Dtype* h,
    Dtype momentum, Dtype local_rate) {
  CUDA_KERNEL_LOOP(i, N) {
    g[i] = h[i] = momentum*h[i] + local_rate*g[i];
  }
}
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  // the learnable parameters
  const vector<float>& net_params_lr = this->net_->params_lr();  // per-parameter learning rate multipliers
  Dtype momentum = this->param_.momentum();  // the momentum, typically 0.9
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    sgd_update_gpu(net_params[param_id]->count(),
        net_params[param_id]->mutable_gpu_diff(),
        history_[param_id]->mutable_gpu_data(),
        momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
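
On the CPU path the same update is realized with two BLAS calls: caffe_cpu_axpby computes history = local_rate * diff + momentum * history, and caffe_copy then copies history back into the parameter diff, which matches g[i] = h[i] = momentum*h[i] + local_rate*g[i] in the kernel.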