AppendTop() and AppendBottom()

Source Code

// Helper for Net::Init: add a new top blob to the net.
template <typename Dtype>
void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
                           const int top_id, set<string>* available_blobs,
                           map<string, int>* blob_name_to_idx) {
  shared_ptr<LayerParameter> layer_param(
      new LayerParameter(param.layer(layer_id)));
  const string& blob_name = (layer_param->top_size() > top_id) ?
      layer_param->top(top_id) : "(automatic)";
  // Check if we are doing in-place computation
  if (blob_name_to_idx && layer_param->bottom_size() > top_id &&
      blob_name == layer_param->bottom(top_id)) {
    // In-place computation
    LOG_IF(INFO, Caffe::root_solver())
        << layer_param->name() << " -> " << blob_name << " (in-place)";
    top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get());
    top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]);
  } else if (blob_name_to_idx &&
             blob_name_to_idx->find(blob_name) != blob_name_to_idx->end()) {
    // If we are not doing in-place computation but have duplicated blobs,
    // raise an error.
    LOG(FATAL) << "Top blob '" << blob_name
               << "' produced by multiple sources.";
  } else {
    // Normal output.
    if (Caffe::root_solver()) {
      LOG(INFO) << layer_param->name() << " -> " << blob_name;
    }
    shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
    const int blob_id = blobs_.size();
    blobs_.push_back(blob_pointer);
    blob_names_.push_back(blob_name);
    blob_need_backward_.push_back(false);
    if (blob_name_to_idx) { (*blob_name_to_idx)[blob_name] = blob_id; }
    top_id_vecs_[layer_id].push_back(blob_id);
    top_vecs_[layer_id].push_back(blob_pointer.get());
  }
  if (available_blobs) { available_blobs->insert(blob_name); }
}

// Helper for Net::Init: add a new bottom blob to the net.
template <typename Dtype>
int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
    const int bottom_id, set<string>* available_blobs,
    map<string, int>* blob_name_to_idx) {
  const LayerParameter& layer_param = param.layer(layer_id);
  const string& blob_name = layer_param.bottom(bottom_id);
  if (available_blobs->find(blob_name) == available_blobs->end()) {
    LOG(FATAL) << "Unknown bottom blob '" << blob_name << "' (layer '"
               << layer_param.name() << "', bottom index " << bottom_id << ")";
  }
  const int blob_id = (*blob_name_to_idx)[blob_name];
  LOG_IF(INFO, Caffe::root_solver())
      << layer_names_[layer_id] << " <- " << blob_name;
  bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
  bottom_id_vecs_[layer_id].push_back(blob_id);
  available_blobs->erase(blob_name);
  bool need_backward = blob_need_backward_[blob_id];
  // Check if the backpropagation on bottom_id should be skipped
  if (layer_param.propagate_down_size() > 0) {
    need_backward = layer_param.propagate_down(bottom_id);
  }
  bottom_need_backward_[layer_id].push_back(need_backward);
  return blob_id;
}
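
For context, Net::Init() drives these two helpers once per layer, bottoms first and then tops; whatever is still left in available_blobs after the last layer becomes the net's output blobs. The following is an abridged sketch of that loop (error checks, automatic top blobs, loss weights and parameter handling omitted; details vary slightly between Caffe versions, see net.cpp in your checkout for the full version):

set<string> available_blobs;
map<string, int> blob_name_to_idx;
for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) {
  const LayerParameter& layer_param = param.layer(layer_id);
  // Figure out this layer's inputs and outputs.
  for (int bottom_id = 0; bottom_id < layer_param.bottom_size(); ++bottom_id) {
    AppendBottom(param, layer_id, bottom_id, &available_blobs, &blob_name_to_idx);
  }
  for (int top_id = 0; top_id < layer_param.top_size(); ++top_id) {
    AppendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx);
  }
  // ... layer SetUp, loss weights, AppendParam, backward flags ...
}
// In the end, all remaining blobs are considered output blobs.
for (set<string>::iterator it = available_blobs.begin();
     it != available_blobs.end(); ++it) {
  net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
  net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
}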

Functional Analysis

Assume the following network structure:

layer {
  name: "datan"
  type: "Data"
  top: "data"
  top: "label"
}
layer {
  name: "ipn"
  type: "InnerProduct"
  bottom: "data"
  top: "ip"
}
layer {
  name: "relun"
  type: "ReLU"
  bottom: "ip"
  top: "ip"
}
layer {
  name: "lossn"
  type: "SoftmaxWithLoss"
  bottom: "ip"
  bottom: "label"
  top: "loss"
}
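
A definition like this normally lives in a .prototxt file; parsing it into a NetParameter and handing it to the Net constructor is what eventually triggers the AppendTop()/AppendBottom() calls analyzed below. A minimal sketch, assuming a hypothetical file name and ignoring that the Data layer would also need a real data_param source before the net could actually be built:

#include "caffe/net.hpp"
#include "caffe/util/upgrade_proto.hpp"

int main() {
  caffe::NetParameter param;
  // Parse the text-format definition shown above ("example_net.prototxt" is hypothetical).
  caffe::ReadNetParamsFromTextFileOrDie("example_net.prototxt", &param);
  param.mutable_state()->set_phase(caffe::TRAIN);
  // Net::Init() runs inside the constructor and calls AppendBottom()/AppendTop()
  // once per bottom/top of every layer.
  caffe::Net<float> net(param);
  return 0;
}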

Process Analysis

datan layer (layer_id = 0)

The Data layer has no bottoms, so Init() only calls AppendTop(), once for "data" and once for "label"; each call registers a new blob. Note that AppendTop() always pushes false into blob_need_backward_; later steps of Init() flip it to true for blobs that actually need gradients.

blob_id  blobs_         blob_names_  blob_need_backward_
0        blob_pointer0  data         false
1        blob_pointer1  label        false

top_id_vecs_[0]  top_vecs_[0]   bottom_id_vecs_[0]  bottom_vecs_[0]  bottom_need_backward_[0]
0                blob_pointer0  (empty)             (empty)          (empty)
1                blob_pointer1

available_blobs is now {data, label}.
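
After these two AppendTop() calls, blob_name_to_idx maps "data" to 0 and "label" to 1. Once Init() has finished, the same bookkeeping is visible through Net's public accessors; a small sketch, assuming the net object from the earlier example (the helper name is hypothetical):

#include "caffe/net.hpp"

// Hypothetical check of the bookkeeping created for the datan layer.
void CheckDataLayerBlobs(const caffe::Net<float>& net) {
  CHECK(net.has_blob("data"));
  CHECK(net.has_blob("label"));
  CHECK_EQ(net.blob_names()[0], "data");   // blob_id 0
  CHECK_EQ(net.blob_names()[1], "label");  // blob_id 1
}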

ipn layer (layer_id = 1)

Init() first calls AppendBottom() for "data" (blob_id 0), then AppendTop() for the new blob "ip":

blob_id  blobs_         blob_names_  blob_need_backward_
0        blob_pointer0  data         false
1        blob_pointer1  label        false
2        blob_pointer2  ip           false

top_id_vecs_[1]  top_vecs_[1]   bottom_id_vecs_[1]  bottom_vecs_[1]  bottom_need_backward_[1]
2                blob_pointer2  0                   blob_pointer0    blob_need_backward_[0]

The value pushed into bottom_need_backward_ is simply blob_need_backward_[blob_id], unless the layer definition carries an explicit propagate_down entry, which overrides it (see the check at the end of AppendBottom()). AppendBottom() also erases "data" from available_blobs, and AppendTop() inserts "ip", so the set is now {label, ip}.
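
The ipn layer in the example does not set propagate_down, so bottom_need_backward_[1][0] simply copies blob_need_backward_[0]. If you wanted to force-skip backpropagation into "data", an explicit propagate_down entry would override that copy; a sketch of such a layer definition built programmatically (equivalent to writing propagate_down: false in the prototxt; the helper name is hypothetical):

#include "caffe/proto/caffe.pb.h"

// Hypothetical variant of the ipn layer with backprop to "data" disabled.
caffe::LayerParameter MakeIpnWithoutBackpropToData() {
  caffe::LayerParameter ip_layer;
  ip_layer.set_name("ipn");
  ip_layer.set_type("InnerProduct");
  ip_layer.add_bottom("data");
  ip_layer.add_top("ip");
  // One propagate_down entry per bottom; false makes AppendBottom() push
  // false into bottom_need_backward_ regardless of blob_need_backward_[0].
  ip_layer.add_propagate_down(false);
  return ip_layer;
}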

relun layer (layer_id = 2)

AppendBottom() wires bottom "ip" (blob_id 2). Because the top is also named "ip", AppendTop() takes the in-place branch and reuses blobs_[2] instead of allocating a new blob, so blobs_ does not grow:

blob_id  blobs_         blob_names_  blob_need_backward_
0        blob_pointer0  data         false
1        blob_pointer1  label        false
2        blob_pointer2  ip           false

top_id_vecs_[2]  top_vecs_[2]   bottom_id_vecs_[2]  bottom_vecs_[2]  bottom_need_backward_[2]
2                blob_pointer2  2                   blob_pointer2    blob_need_backward_[2]

AppendBottom() erases "ip" from available_blobs, and the in-place AppendTop() inserts it back, so the set stays {label, ip}.
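
Because relun's only top and bottom share the name "ip", the in-place branch pushes the existing blobs_[2] pointer rather than a new Blob. A quick sanity check, assuming the net object from the earlier example, where relun is layer_id 2 (helper name hypothetical):

#include "caffe/net.hpp"

// Hypothetical check: for the in-place ReLU, bottom and top are literally
// the same Blob object (blobs_[2]), so the two vectors hold the same pointer.
void CheckReluInPlace(const caffe::Net<float>& net) {
  CHECK_EQ(net.bottom_vecs()[2][0], net.top_vecs()[2][0]);  // same Blob<float>*
}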

lossn layer (layer_id = 3)

AppendBottom() is called for "ip" (blob_id 2) and "label" (blob_id 1), then AppendTop() registers the new blob "loss":

blob_id  blobs_         blob_names_  blob_need_backward_
0        blob_pointer0  data         false
1        blob_pointer1  label        false
2        blob_pointer2  ip           false
3        blob_pointer3  loss         false

top_id_vecs_[3]  top_vecs_[3]   bottom_id_vecs_[3]  bottom_vecs_[3]  bottom_need_backward_[3]
3                blob_pointer3  2                   blob_pointer2    blob_need_backward_[2]
                                1                   blob_pointer1    blob_need_backward_[1]

"ip" and "label" are erased from available_blobs and "loss" is inserted, leaving {loss}; after the last layer, Init() treats every name still left in available_blobs as a network output.
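
With all four layers processed, the wiring built up by AppendTop()/AppendBottom() can be read back through Net's public accessors (layer_names(), blob_names(), bottom_vecs(), top_vecs()). A small sketch that prints the per-layer fan-in/fan-out for the example net (the helper name is hypothetical):

#include <cstdio>
#include "caffe/net.hpp"

// Hypothetical helper: print how many bottoms/tops each layer was wired with
// and list the registered blob names, mirroring the tables above.
template <typename Dtype>
void DumpNetWiring(const caffe::Net<Dtype>& net) {
  for (size_t l = 0; l < net.layer_names().size(); ++l) {
    std::printf("layer %zu (%s): %zu bottom(s), %zu top(s)\n",
                l, net.layer_names()[l].c_str(),
                net.bottom_vecs()[l].size(), net.top_vecs()[l].size());
  }
  for (size_t b = 0; b < net.blob_names().size(); ++b) {
    std::printf("blob %zu: %s\n", b, net.blob_names()[b].c_str());
  }
}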

Summary

AppendTop() registers every output blob of a layer in blobs_ / blob_names_ / blob_need_backward_ (reusing the existing blob when the layer works in place), while AppendBottom() looks each input up among the blobs registered so far and wires it into bottom_vecs_ / bottom_id_vecs_ / bottom_need_backward_. For the example network, the resulting blob graph is:

datan --blob<0> ("data")--> ipn --blob<2> ("ip")--> relun --blob<2> ("ip", in-place)--> lossn --> blob<3> ("loss")
datan --blob<1> ("label")-----------------------------------------------------------> lossn