Caffe中的Net类是如何工作的?

Net类是Caffe中Blobs,Layers,Nets三个抽象层次中最高层的抽象。Nets类负责按照网络定义文件将需要的layers和中间blobs进行实例化,并将所有的Layers组合成一个有向无环图。Nets还提供了在整个网络上进行前向传播与后向传播的接口。下面从观察Net运行的角度来解析一下Net类如何工作。

Net类数据成员概述

下面对Net类中比较重要的数据成员进行说明:

vector<shared_ptr<Layer<Dtype> > > layers_;
layers_中存放着网络的所有layers,也就是Net类的实例保存着网络定义文件中所有layer的实例

vector<shared_ptr<Blob<Dtype> > > blobs_;
blobs_中保存着网络所有的中间结果,即所有layer的输入数据(bottom blob)和输出数据(top blob)

vector<vector<Blob<Dtype>*> > bottom_vecs_;
vector<vector<Blob<Dtype>*> > top_vecs_;
bottom_vecs_保存的是各个layer的bottom blob的指针,这些指针指向blobs_中的blob。bottom_ves.size()与网络layer的数量相等,由于layer可能有多个bottom blob,所以使用vector<Blob<Dtype>*>来存放layer-wise的bottom blob。同理可以知道top_vecs的作用。

vector<shared_ptr<Blob<Dtype> > > params_;
vector<Blob<Dtype>*> learnable_params_;
上述两个数据成员存放的是指向网络参数的指针,注意,直接拥有参数的是layer,params_保存的只是网络中各个layer的参数的指针;而learnable_params_也如其名字所指,保存的是各个layer中可以被学习的参数。

Net类的实例化(一个网络的建立)

构造函数

Net类有两个构造函数,分别是Net(const NetParameter& param, const Net* root_net)Net(const string& param_file, Phase phase, const Net* root_net),前者接受NetParameter的const引用作为参数(后面参数root_net与多GPU并行训练有关,忽略掉并不影响理解),后者接受定义网络prototxt文件路径和phase作为输入。
前者直接调用Init()函数,后者将prototxt文件解析为NetPrameter后调用Init()函数。

Init()函数

Init()函数承担初始化一个网络的任务,摘取主干代码描述如下(忽略细节,大致描述过程):

for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) {//param是网络参数,layer_size()返回网络拥有的层数
    const LayerParameter& layer_param = param.layer(layer_id);//获取当前layer的参数
    layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));//根据参数实例化layer


//下面的两个for循环将此layer的bottom blob的指针和top blob的指针放入bottom_vecs_和top_vecs_,bottom blob和top blob的实例全都存放在blobs_中。相邻的两层,前一层的top blob是后一层的bottom blob,所以blobs_的同一个blob既可能是bottom blob,也可能使top blob。
    for (int bottom_id = 0; bottom_id < layer_param.bottom_size();++bottom_id) {
       const int blob_id=AppendBottom(param,layer_id,bottom_id,&available_blobs,&blob_name_to_idx);
    }

    for (int top_id = 0; top_id < num_top; ++top_id) {
       AppendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx);
    }

//接下来的工作是将每层的parameter的指针塞进params_,尤其是learnable_params_。
   const int num_param_blobs = layers_[layer_id]->blobs().size();
   for (int param_id = 0; param_id < num_param_blobs; ++param_id) {
       AppendParam(param, layer_id, param_id);
       //AppendParam负责具体的dirtywork
    }

    }
初始化之后

经过上述过程的网络,参数都是随机产生或者指定的,如果进行预测或这fine-tuning,就需要将载入预训练的权值,Net类提供的函数CopyTrainedLayersFrom(const string& trained_file)可以实现这个过程。

网络的运行(前向传播, 反向传播和权值更新)

Net类可以提供网络级的前向前向传播、反向传播和权值更新(即在网络的所有层上有序执行前述动作)。

前向传播

与前向传播相关的函数有Forward(const vector<Blob<Dtype>*> & bottom, Dtype* loss),Forward(Dtype* loss),ForwardTo(int end)ForwardFrom(int start)ForwardFromTo(int start, int end),前面的四个函数都是对第五个函数封装,第五个函数定义如下:

template <typename Dtype>
Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
CHECK_GE(start, 0);
CHECK_LT(end, layers_.size());
Dtype loss = 0;
for (int i = start; i <= end; ++i) {
// LOG(ERROR) << "Forwarding " << layer_names_[i];
Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
loss += layer_loss;
if (debug_info_) { ForwardDebugInfo(i); }
}
return loss;
}

重点语句是layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);,使用layer对应bottom blob和top blob进行前向传播。

反向传播

与前向传播一样,反向传播也有很多相关函数,但都是对BackwardFromTo(int start, int end)的封装。

  CHECK_GE(end, 0);
  CHECK_LT(start, layers_.size());
  for (int i = start; i >= end; --i) {
    if (layer_need_backward_[i]) {
      layers_[i]->Backward(top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]);
      if (debug_info_) { BackwardDebugInfo(i); }
    }
  }
}

与前向传播相反,反向传播是从尾到头进行的。

权值更新
template <typename Dtype>
void Net<Dtype>::Update() {
  for (int i = 0; i < learnable_params_.size(); ++i) {
    learnable_params_[i]->Update();
  }
}

在训练的过程中layer的权值要根据反向传播并累积的梯度进行更新,更新的过程由Update()完成。这个函数的功能十分明确,对每个存储learnable_parms的blob调用blob的Update()函数,来更新权值。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值