mxnet代码解析之computation graph

mxnet的前后向计算是以图模型构建的,它有两个重要的类结构,一个是Symbol,另一个是StaticGraph。Symbol和StaticGraph可以相互转换,Symbol提供了灵活的方式来合成节点,StaticGraph则包含了实际的配置。

Symbol

symbol类表示网络结构动态生成的符号计算图,symbol基本结构如下:

class Symbol {
protected:
  struct Node;
  struct DataEntry {
    std::shared_ptr<Node> source;
    //source输出的index
    uint32_t index;
    ......
  };
  std::vector<DataEntry> heads_;
  ......
  void Compose(const std::vector<Symbol>& args,const std::string& name);
  void Compose(const std::unordered_map<std::string, Symbol>& kwargs,
               const std::string& name);
  Symbol Grad(const std::vector<std::string>& wrt) const;
}

其中head_表示symbol的输出节点, 而DataEntry类与Node类的定义存在交叉,Node类的定义如下:

struct Symbol::Node {
  std::unique_ptr<OperatorProperty> op;
  std::string name;
  std::vector<DataEntry> inputs;
  //当前节点的源节点
  std::shared_ptr<Symbol::Node> backward_source_node;
  std::unique_ptr<std::map<std::string, std::string> > attr;
  ......
  ......
  template<typename FVisit>
    inline void Symbol::DFSVisit(FVisit fvisit) const {
      typedef const std::shared_ptr<Node>* GNode;
      std::vector<GNode> head_nodes(heads_.size());
      std::transform(heads_.begin(), heads_.end(), head_nodes.begin(),
                     [](const DataEntry& e)->GNode {
                       return &e.source;
                     });
      graph::PostOrderDFSVisit<GNode, Node*>(
          head_nodes,
          [fvisit](GNode n) { fvisit(*n); },  // FVisit
          [](GNode n)->Node* { return n->get(); },  // HashFunc
          [](GNode n)->uint32_t { return (*n)->inputs.size() +
                static_cast<int>((*n)->is_backward()); },  // InDegree
          [](GNode n, uint32_t index)->GNode {  // GetInput
            if (index < (*n)->inputs.size()) {
              return &(*n)->inputs.at(index).source;
            } else {
              return &(*n)->backward_source_node;
            }
          });
    }
  ......
  ......
}

Node有三种类型:普通节点、操作属性、可变的。可变的节点表示操作为空的参数节点。
Symbol::Compose用于合成节点,以提供完整的位置参数为例:

void Symbol::Compose(const std::vector<Symbol>& args,
                     const std::string& name) {
  // CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently";
  CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed";
  heads_[0].source->name = name;
  for (size_t i = 0; i < args.size(); ++i) {
    CHECK_EQ(args[i].NumOutputs(), 1)
        << "Argument " << i << " is a tuple with " <<  args[i].NumOutputs()
        << " elements, scalar is required";
  }
  //原子symbol没有占位符,即没有分配输入,因此需要给输入赋值
  if (this->is_atomic()) {
    // atomic symbol do not have place holder for all the arguments
    std::vector<std::string> req_args = heads_[0].source->op->ListArguments();
    CHECK_LE(args.size(), req_args.size())
        << "Incorrect number of arguments, requires " << req_args.size()
        << ", provided " << args.size();
    heads_[0].source->inputs.resize(req_args.size());
    for (size_t i = 0; i < args.size(); ++i) {
      heads_[0].source->inputs[i] = args[i].heads_[0];
    }
    for (size_t i = args.size(); i < req_args.size(); ++i) {
      heads_[0].source->inputs[i] = DataEntry(
          std::make_shared<Node>(nullptr, DefaultVarName(name, req_args[i])), 0);
      // also copy attribute of operator over to automatically created variable
      if (heads_[0].source->attr.get() != nullptr) {
        heads_[0].source->inputs[i].source->attr.reset(
            new std::map<std::string, std::string>(*(heads_[0].source->attr)));
      }
    }
  } else {
    // find all the place holders
    size_t arg_counter = 0;
    std::unordered_map<Node*, const DataEntry*> replace_map;
    std::vector<std::pair<DataEntry*, const DataEntry*> > replace_plan;
    // replace map stores the existing replacement plan for arguments node
    this->DFSVisit([&arg_counter, &replace_map, &replace_plan, &args]
                   (const std::shared_ptr<Node> &node) {
        // visit all the childs, find possible replacement
        for (size_t i = 0; i < node->inputs.size(); ++i) {
          DataEntry *e = &(node->inputs[i]);
          if (e->source->is_variable()) {
            const DataEntry *target = nullptr;
            auto iter = replace_map.find(e->source.get());
            if (iter == replace_map.end()) {
              if (arg_counter < args.size()) {
                target = &(args[arg_counter].heads_[0]);
                replace_map[e->source.get()] = target;
              }
              ++arg_counter;
            } else {
              target = iter->second;
            }
            replace_plan.push_back(std::make_pair(e, target));
          }
        }
      });
    CHECK_EQ(args.size(), arg_counter)
        << "Incorrect number of arguments, requires " << arg_counter
        << ", provided " << args.size();
    // now run the replacement
    for (const auto& kv : replace_plan) {
      *(kv.first) = *(kv.second);
    }
  }
}

DFSVisit是用来访问每个网络节点的主要方法,FVisit 是实际对每个节点执行的操作,PostOrderDFSVisit是有向图深度优先搜索算法的真正实现:

template <typename GNode, typename HashType, typename FVisit,
          typename HashFunc, typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads, FVisit fvisit,
                       HashFunc hash, InDegree indegree, GetInput getinput) {
  std::vector<std::pair<GNode, uint32_t> > stack;
  std::unordered_set<HashType> visited;
  for (auto& head : heads) {
    HashType head_hash = hash(head);
    if (visited.count(head_hash) == 0) {
      stack.push_back(std::make_pair(head, 0));
      visited.insert(head_hash);
    }
    while (!stack.empty()) {
      std::pair<GNode, uint32_t>& back = stack.back();
      if (back.second == indegree(back.first)) {
        fvisit(back.first);
        stack.pop_back();
      } else {
        const GNode& input = getinput(back.first, back.second++);
        HashType input_hash = hash(input);
        if (visited.count(input_hash) == 0) {
          stack.push_back(std::make_pair(input, 0));
          visited.insert(input_hash);
        }
      }
    }
  }
}

这里采用后向深度优先遍历,即从尾节点向上推(与经典的递归深度优先搜索算法的从开始节点向后搜索不同),input相当于图的前驱节点,backward_source_node也将作为当前节点的最后一个input压入栈中参与搜索。

StaticGraph

class StaticGraph {
 public:
  struct DataEntry {
    /*! \brief the source node id in the computation graph */
    uint32_t source_id;
    /*! \brief index of output from the source. */
    uint32_t index;
  };
  /*!
   * \brief Operation Node in static graphs.
   *
   *  The reason we explicit support Backward node is to allow special treatment
   *  such as shape inference and state sharing with Forward pass.
   */
  struct Node {
    std::unique_ptr<OperatorProperty> op;
    std::string name;
    std::vector<DataEntry> inputs;
    int32_t backward_source_id;
    std::map<std::string, std::string> attr;
    /*!
     *
     *  Let n = inputs.size() - addto_index_.size();
     *    the output of the node is defined as:
     *  - out[j] = op(input[0:n]) for j not in addto_index_
     *  - out[addto_index_[i]] = op(input[0:n]) + inputs[n + i]
     */
    std::vector<uint32_t> addto_index;
  ......
  ......
  };
  /*! \brief all nodes in the graph */
  std::vector<Node> nodes;
  /*! \brief index of nodes that correspods to arguments */
  std::vector<uint32_t> arg_nodes;
  /*! \brief heads outputs of the graph */
  std::vector<DataEntry> heads;
  void Save(dmlc::JSONWriter *writer) const;
  void Load(dmlc::JSONReader *reader);

  std::vector<uint32_t> TopoSort() const;
  std::vector<uint32_t> PostDFSOrder(const std::vector<uint32_t>& head_nodes) const;
  void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
                        std::vector<DataEntry> *arg_grads,
                        std::map<uint32_t, uint32_t>* out_mirror_map);
  ......
}

StaticGraph中重新定义了DataEntry与Node, 把Node的指针类型都用索引来替代,Node中添加了addto_index节省梯度计算时的内存。

StaticGraph是以一维数组存储图模型的,它根据symbol的拓补排序排列元素,TopoSort函数对输入个数为0的节点进行后向深度优先遍历得到拓补排序。
MakeBackwardPass用于构建后向通路,在函数开头会调用TopoSort,函数主体会对传进去的三个参数赋值,并且给nodes成员push_back新对象,push_back顺序为:
1. 如果有mirror,则将需要mirror的node复制并压入nodes。
2. 为每个heads的新建Node,作为梯度头结点压入。
3. 如果node有多个输出,并且该输出是其他op的输入,则调用CreateGradSumNode创建新的节点并压入。
3. 创建梯度后向节点并压入。
4. 如果参数梯度node是多个node的输入,则调用CreateGradSumNode创建新的节点并压入。

GraphExecutor

在以symbol合成图模型后,便可以转化为StaticGraph,然后用GraphExecutor执行StaticGraph,GraphExecutor类如下:

class GraphExecutor : public Executor {
 public:
  GraphExecutor() {}
  virtual ~GraphExecutor();
  void Forward(bool is_train) override;
  void PartialForward(bool is_train, int step, int *step_left) override;
  void Backward(const std::vector<NDArray> &head_grads) override;
  const std::vector<NDArray> &outputs() const override {
    return heads_ndarray_;
  }
  ......
  // 在Executor::Bind中调用, 且只调用一次.
  inline void Init(Symbol symbol,
                   const Context& default_ctx,
                   const std::map<std::string, Context>& ctx_map,
                   const std::vector<NDArray> &in_args,
                   const std::vector<NDArray> &arg_grad_store,
                   const std::vector<OpReqType> &grad_req_type,
                   const std::vector<NDArray> &aux_states,
                   Executor* shared_exec = nullptr) {
    enable_inplace_allocation_ = dmlc::GetEnv("MXNET_EXEC_ENABLE_INPLACE", true);
    prefer_bulk_execution_ = dmlc::GetEnv("MXNET_EXEC_PREFER_BULK_EXEC", true);
    if (shared_exec != NULL) {
      GraphExecutor* gexec = dynamic_cast<GraphExecutor*>(shared_exec);
      CHECK(gexec) << "Input executor for sharing memory must have GraphExecutor type.";
      shared_mem_ = gexec->shared_mem_;
    } else {
      shared_mem_ = std::make_shared<GraphStoragePool>();
    }

    CHECK_EQ(grad_req_type.size(), arg_grad_store.size());
    bool need_backward = false;
    for (auto req : grad_req_type) {
      if (req != kNullOp) need_backward = true;
    }
    this->InitGraph(symbol, default_ctx, ctx_map,
                    in_args, arg_grad_store, grad_req_type,
                    need_backward);
    this->InitDataEntryInfo(in_args, arg_grad_store, grad_req_type, aux_states);
    this->InitOperators();
    this->InitDataEntryMemory();
    this->InitResources();
    this->InitCachedOps();
  }

 protected:
  // internal class of wrapping BackwardOp as ForwardOp
  class BackwardOpWrapper;
  // type of data entry
  enum DataEntryType {
    // memory is bound by external NDArray in Bind
    kBindByExternal,
    // to be bound by external NDArray in Forward and Backward
    kTobeBindByExternal,
    // internal memory, allocated
    kInternalAllocated,
    // internal memory, to be allocated
    kNotInitialized
  };
  // Additional information about each data entry
  struct DataEntryInfo {
    // the actual data for the entry
    NDArray data;
    // write request to this entry
    OpReqType op_req;
    // the operatio node that will take
    // this DataEntry as inplace input
    int inplace_op_id;
    // data entry type
    DataEntryType type;
    // shape of this entry
    TShape shape;
    // data type of this entry
    int type_flag;
    // storage id from allocator if it is internal allocation.
    GraphStorageAllocator::StorageID storage_id;
    // reference count on how many times this entry is being used.
    // That is how many operators and heads need this DataEntry
    // this is a temporal variable that is used during initialization.
    uint32_t temp_ref_count;
    // real permanent ref count
    uint32_t ref_count;
    // constructor
    DataEntryInfo()
        : op_req(kNullOp),
          inplace_op_id(-1),
          type(kNotInitialized),
          storage_id(GraphStorageAllocator::kBadStorageID),
          temp_ref_count(0), ref_count(0) {}
  };
  // all the information needed to push the op to engine
  struct OpExecEntry {
    // execution function for
    Engine::AsyncFn exec_fun;
    // variables to read from
    std::vector<Engine::VarHandle> use_vars;
    // variables to mutate
    std::vector<Engine::VarHandle> mutate_vars;
    // constructor
    OpExecEntry() : exec_fun(nullptr) {}
  };
  // Information about operational node
  struct OpNode {
    // whether this op node is activated
    bool activated;
    // the context of the node
    Context ctx;
    // data entry information about outputs of op
    std::vector<DataEntryInfo> outputs;
    // auxiliary data information of op
    std::vector<DataEntryInfo> aux_states;
    // The following parts are constructed in InitOpNodes
    // the real operator
    std::shared_ptr<Operator> op;
    // op context, that is defined for this op.
    OpContext op_ctx;
    // executor, this is only allocated for nodes
    // whose inputs, outputs are pre-defined.
    // otherwise cached_exec.exec_fun == nullptr
    OpExecEntry cached_exec;
    // cached operator handle
    Engine::OprHandle cached_opr{nullptr};
    // constructor
    OpNode() : activated(false) {}
    // Manual option for delete operator
    // need to do this before delete NDArrays
    inline void DeleteOperator() {
      if (cached_opr != nullptr) {
        Engine::Get()->DeleteOperator(cached_opr);
        cached_opr = nullptr;
      }
    }
  };
  ......
  ......
  inline OpExecEntry GetOpExecEntry(uint32_t node_id);
  // run ops from topo order start to end
  void RunOps(bool is_train, size_t topo_start, size_t topo_end);
  // internal computational graph
  StaticGraph graph_;
  // topological order of nodes in computation graph
  // backward nodes always follow forward nodes
  std::vector<uint32_t> topo_order_;
  // whether to enable inplace space
  bool enable_inplace_allocation_;
  // total allocated space in bytes
  size_t total_allocated_bytes_;
  // total allocated temp space
  size_t total_allocated_temp_;
  // number of forward nodes in the graph
  size_t num_forward_nodes_;
  // whether to enable bulk execution
  bool prefer_bulk_execution_;
  // head gradient node in the graph, if there is backward pass
  std::vector<uint32_t> head_grad_nodes_;
  // mirror map of nodes, experimental feature, normally can be ignored.
  std::map<uint32_t, uint32_t> mirror_source_map_;
  // argument node in the graph, if there is backward pass
  std::vector<StaticGraph::DataEntry> arg_grads_;
  // 操作节点,与graph_.nodes一一对应
  std::vector<OpNode> op_nodes_;
  // head NDArrays
  std::vector<NDArray> heads_ndarray_;
  // shared NDArrays
  std::shared_ptr<GraphStoragePool> shared_mem_;
  // monitor call back
  std::function<void(const char*, void*)> monitor_callback_;
  // cached segment operator
  std::unordered_map<size_t, Engine::OprHandle> cached_seg_opr_;
};  // class GraphExecutor

GraphExecutor中,Forward和Backward是图模型的前后向运算函数,它们都要调用一个核心函数RunOps,该函数如下:

void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
  // heurestic, only enable bulk on forward only
  bool bulk_exec = prefer_bulk_execution_ && !monitor_callback_
      && topo_start == 0 && num_forward_nodes_ == topo_order_.size();
  //如果批量执行
  if (bulk_exec) {
    // encode things into a key
    size_t key = topo_start * op_nodes_.size() + topo_end;
    if (cached_seg_opr_.count(key) == 0) {
      cached_seg_opr_[key] = this->CreateCachedOpr(topo_start, topo_end);
      if (cached_seg_opr_.at(key) != nullptr) {
        LOG(INFO) << "Created bulk execution on segment ["
                  << topo_start << ", " << topo_end << ")";
      }
    }
    auto cached_op = cached_seg_opr_.at(key);
    if (cached_op != nullptr) {
      Context* pctx = nullptr;
      for (size_t i = topo_start; i < topo_end; ++i) {
        uint32_t nid = topo_order_[i];
        if (!op_nodes_[nid].activated) continue;
        if (graph_.nodes[nid].is_variable()) continue;
        OpNode& opnode = op_nodes_[nid];
        opnode.op_ctx.is_train = is_train;
        pctx = &(opnode.ctx);
      }
      Engine::Get()->Push(cached_op, *pctx);
      return;
    }
  }

  for (size_t i = topo_start; i < topo_end; ++i) {
    uint32_t nid = topo_order_[i];
    if (!op_nodes_[nid].activated) continue;
    if (graph_.nodes[nid].is_variable()) continue;
    OpNode& opnode = op_nodes_[nid];
    // special handle cross device copy op
    if (opnode.op->exec_type() == Operator::kCrossDeviceCopy) {
      CHECK_EQ(graph_.nodes[nid].inputs.size(), 1);
      CHECK_EQ(opnode.outputs.size(), 1);
      auto in = graph_.nodes[nid].inputs[0];
      CopyFromTo(op_nodes_[in.source_id].outputs[in.index].data,
                 &(opnode.outputs[0].data));
      continue;
    }
    opnode.op_ctx.is_train = is_train;
    if (opnode.cached_opr != nullptr) {
      Engine::Get()->Push(opnode.cached_opr, opnode.ctx);
    } else {
      auto exec = GetOpExecEntry(nid);
      Engine::Get()->PushAsync(
          exec.exec_fun,
          opnode.ctx,
          exec.use_vars,
          exec.mutate_vars,
          FnProperty::kNormal);
    }
    if (monitor_callback_) {
      std::vector<std::string> output_names;
      if (graph_.nodes[nid].is_forward()) {
        output_names = graph_.nodes[nid].op->ListOutputs();
      } else {
        int source_id = graph_.nodes[nid].backward_source_id;
        output_names = graph_.nodes[source_id].op->ListArguments();
      }
      for (index_t i = 0; i < opnode.outputs.size(); ++i) {
        NDArray out_data = opnode.outputs[i].data;
        std::string name = graph_.nodes[nid].name + "_" + output_names[i];
        NDArray *cpy = new NDArray(out_data);
        this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
      }
    }
  }
}

RunOps中如果执行批量操作,则需要创建缓存操作集,然后将其压入执行引擎;否则就要一个节点一个节点地创建操作并压入执行引擎,GetOpExecEntry是用来获得每一个操作节点的执行函数以便压入执行引擎,GetOpExecEntry的核心操作是调用Forawrd函数,在图模型中,前向节点的Forward就是调用各种Operator(如Convolution)的Forward函数,后向节点则将其对应的Operator的Backward封装到BackwardOpWrapper的Forward函数中来统一GetOpExecEntry中的调用操作。

GraphExecutor中另一类函数,就是在Init中调用的各种初始化函数,是用来初始化图模型的。

InitGraph主要包含三个主要步骤:
1. 调用MakeBackwardPass构造backward通路。
2. 调用AssignContext为每个节点分配设备环境,分配结果保存在op_nodes_中,如果某节点输入与该节点的设备不一致,则需要创建copy节点,并插入到nodes尾部。
3. 对上述两个步骤后的nodes进行特定的拓补排序,以保证backward在forward之后,将结果保存在topo_order_中。

InitDataEntryInfo主要是根据传入的参数推断shape和type,并填充op_nodes_的data、type、type_flag、ref_count、activated等属性。

InitOperators将对InitDataEntryInfo激活的OpNode以及非参数Node执行初始化,前向节点调用CreateOperatorEx初始化OpNode的Operator属性,后向节点调用BackwardOpWrapper构造函数初始化OpNode的Operator属性。

InitDataEntryMemory用GraphStorageAllocator类为GraphExecutor分配内存,GraphStorageAllocator分配内存分为两步,第一步由GraphExecutor调用request和release根据依赖计划将要请求和释放的资源,每一个request调用会返回一个StorageID来标记分配给每一个DataEntryInfo的内存块;第二步调用InitStorages分配真实的内存。

InitDataEntryMemory函数块首先检查nodes输入以及op_nodes_输出的状态,这些状态保存在DataEntryInfo中,即type和temp_ref_count属性;nodes输入(type与temp_ref_count)必须已经处于初始化后的状态,op_nodes_输出(type)必须尚未被内部分配。检查通过后,分配内存前先调用GetInplaceOption获得有inplace操作的DataEntryInfo项,并修改相应输入输出状态。然后便可以采用GraphStorageAllocator的两步走分配内存。最后为GraphExecutor的heads_ndarray_压入数据。

InitResources用于分配临时内存,例如像cudnnConvolutionForward这样的操作需要一个workspace来辅助计算。InitResources通过调用GetResource获取某一个具体ResourceRequest,这个ResourceRequest又是有具体的operator调用ForwardResource或BackwardResource生成的。mxnet用不同颜色标记图模型中可以并行的node,因此在请求临时空间资源时,将同一种颜色的node的请求压入OpNode的属性OpContext的Resource属性中,每种颜色的node在第一次遇到请求都会用ResourceManger类创建新的Resource。而颜色的分组是通过调用ColorNodeGroup完成的,一共有max_ncolor种颜色,ColorNodeGroup中有max_ncolor次循环,每次循环找到一种颜色的执行全路径。

InitCachedOps为OpNode填充cached_opr属性,前提是这个OpNode的输出及其对应的Node的输入的DataEntryType并不是被外部绑定的,是预定义了的,且这个OpNode的的Operator不是交叉设备的拷贝操作。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值