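MXNet builds its forward and backward computation as a graph. Two classes are central to this: Symbol and StaticGraph. They can be converted into each other. Symbol provides a flexible way to compose nodes, while StaticGraph holds the concrete graph configuration that is actually executed.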
Symbol
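The Symbol class represents the symbolic computation graph that is built up dynamically as the network is defined. Its basic structure is: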
class Symbol {
protected:
struct Node;
struct DataEntry {
std::shared_ptr<Node> source;
// index of the output of source
uint32_t index;
......
};
std::vector<DataEntry> heads_;
......
void Compose(const std::vector<Symbol>& args,const std::string& name);
void Compose(const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name);
Symbol Grad(const std::vector<std::string>& wrt) const;
};
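Here heads_ holds the output entries of the symbol. DataEntry and Node refer to each other: a DataEntry points to the Node that produces it, and a Node stores its inputs as DataEntry objects. Node is defined as follows: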
struct Symbol::Node {
std::unique_ptr<OperatorProperty> op;
std::string name;
std::vector<DataEntry> inputs;
// for a backward node: the forward node whose gradient it computes (null otherwise)
std::shared_ptr<Symbol::Node> backward_source_node;
std::unique_ptr<std::map<std::string, std::string> > attr;
......
};
// DFSVisit is a member function of Symbol (not of Node):
template<typename FVisit>
inline void Symbol::DFSVisit(FVisit fvisit) const {
typedef const std::shared_ptr<Node>* GNode;
std::vector<GNode> head_nodes(heads_.size());
std::transform(heads_.begin(), heads_.end(), head_nodes.begin(),
[](const DataEntry& e)->GNode {
return &e.source;
});
graph::PostOrderDFSVisit<GNode, Node*>(
head_nodes,
[fvisit](GNode n) { fvisit(*n); }, // FVisit
[](GNode n)->Node* { return n->get(); }, // HashFunc
[](GNode n)->uint32_t { return (*n)->inputs.size() +
static_cast<int>((*n)->is_backward()); }, // InDegree
[](GNode n, uint32_t index)->GNode { // GetInput
if (index < (*n)->inputs.size()) {
return &(*n)->inputs.at(index).source;
} else {
return &(*n)->backward_source_node;
}
});
}
......
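A Node comes in one of three kinds: a normal node (an operator applied to its inputs), an OperatorProperty node (an operator whose inputs are still empty, i.e. not yet applied), and a variable. A variable node is an argument node whose op is null.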
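To make the three kinds concrete, here is a minimal sketch that classifies a node. The structs are hypothetical, simplified stand-ins for Symbol::Node and Symbol::DataEntry (OperatorProperty is reduced to a name), and the is_variable test follows the description above: no op and no backward source.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for Symbol::Node / Symbol::DataEntry.
struct OperatorProperty {  // placeholder for the real operator description
  std::string type;
};

struct Node;

struct DataEntry {
  DataEntry(std::shared_ptr<Node> src, unsigned idx) : source(std::move(src)), index(idx) {}
  std::shared_ptr<Node> source;  // node that produces this value
  unsigned index;                // which output of `source`
};

struct Node {
  std::unique_ptr<OperatorProperty> op;        // null for a variable
  std::string name;
  std::vector<DataEntry> inputs;
  std::shared_ptr<Node> backward_source_node;  // set only for backward nodes

  bool is_variable() const { return op == nullptr && !backward_source_node; }
  bool is_backward() const { return backward_source_node != nullptr; }
};

enum class NodeKind { kVariable, kUnappliedOp, kNormal };

// Sort a node into the three kinds described above.
NodeKind Classify(const Node& n) {
  if (n.is_variable()) return NodeKind::kVariable;
  if (n.inputs.empty()) return NodeKind::kUnappliedOp;  // operator not yet applied
  return NodeKind::kNormal;
}
```

An operator node turns into a normal node once Compose fills in its inputs, which is exactly what the next function does.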
Symbol::Compose composes nodes. Taking the overload that receives positional arguments as an example:
void Symbol::Compose(const std::vector<Symbol>& args,
const std::string& name) {
// CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently";
CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed";
heads_[0].source->name = name;
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i].NumOutputs(), 1)
<< "Argument " << i << " is a tuple with " << args[i].NumOutputs()
<< " elements, scalar is required";
}
// an atomic symbol has no placeholder inputs yet, so its inputs are assigned here
if (this->is_atomic()) {
// atomic symbol do not have place holder for all the arguments
std::vector<std::string> req_args = heads_[0].source->op->ListArguments();
CHECK_LE(args.size(), req_args.size())
<< "Incorrect number of arguments, requires " << req_args.size()
<< ", provided " << args.size();
heads_[0].source->inputs.resize(req_args.size());
for (size_t i = 0; i < args.size(); ++i) {
heads_[0].source->inputs[i] = args[i].heads_[0];
}
for (size_t i = args.size(); i < req_args.size(); ++i) {
heads_[0].source->inputs[i] = DataEntry(
std::make_shared<Node>(nullptr, DefaultVarName(name, req_args[i])), 0);
// also copy attribute of operator over to automatically created variable
if (heads_[0].source->attr.get() != nullptr) {
heads_[0].source->inputs[i].source->attr.reset(
new std::map<std::string, std::string>(*(heads_[0].source->attr)));
}
}
} else {
// find all the place holders
size_t arg_counter = 0;
std::unordered_map<Node*, const DataEntry*> replace_map;
std::vector<std::pair<DataEntry*, const DataEntry*> > replace_plan;
// replace map stores the existing replacement plan for arguments node
this->DFSVisit([&arg_counter, &replace_map, &replace_plan, &args]
(const std::shared_ptr<Node> &node) {
// visit all the children, find possible replacement
for (size_t i = 0; i < node->inputs.size(); ++i) {
DataEntry *e = &(node->inputs[i]);
if (e->source->is_variable()) {
const DataEntry *target = nullptr;
auto iter = replace_map.find(e->source.get());
if (iter == replace_map.end()) {
if (arg_counter < args.size()) {
target = &(args[arg_counter].heads_[0]);
replace_map[e->source.get()] = target;
}
++arg_counter;
} else {
target = iter->second;
}
replace_plan.push_back(std::make_pair(e, target));
}
}
});
CHECK_EQ(args.size(), arg_counter)
<< "Incorrect number of arguments, requires " << arg_counter
<< ", provided " << args.size();
// now run the replacement
for (const auto& kv : replace_plan) {
*(kv.first) = *(kv.second);
}
}
}
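The non-atomic branch of Compose can be reproduced on a toy graph. The sketch below is hypothetical (ToyNode, Var, Op and ComposeToy are illustrative names, output indices are dropped, and the DFS is recursive for brevity), but it follows the same plan: walk the graph in post-order, match each variable input to the next positional argument the first time it is seen, reuse that match when the same variable appears again, and only apply the replacements after the walk.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy node: a variable has no op; inputs are plain shared_ptrs (output index omitted).
struct ToyNode {
  std::string name;
  bool is_variable;
  std::vector<std::shared_ptr<ToyNode>> inputs;
};
using NodePtr = std::shared_ptr<ToyNode>;

NodePtr Var(const std::string& n) { return std::make_shared<ToyNode>(ToyNode{n, true, {}}); }
NodePtr Op(const std::string& n, std::vector<NodePtr> in) {
  return std::make_shared<ToyNode>(ToyNode{n, false, std::move(in)});
}

// Post-order DFS: inputs are visited before the node itself.
void PostOrder(const NodePtr& n, std::unordered_set<ToyNode*>* seen,
               const std::function<void(const NodePtr&)>& fvisit) {
  if (!seen->insert(n.get()).second) return;
  for (auto& in : n->inputs) PostOrder(in, seen, fvisit);
  fvisit(n);
}

// Replace variable placeholders under `head` with `args`, in first-visit order.
void ComposeToy(const NodePtr& head, const std::vector<NodePtr>& args) {
  size_t arg_counter = 0;
  std::unordered_map<ToyNode*, NodePtr> replace_map;          // variable -> argument
  std::vector<std::pair<NodePtr*, NodePtr>> replace_plan;     // slot -> new value
  std::unordered_set<ToyNode*> seen;
  PostOrder(head, &seen, [&](const NodePtr& node) {
    for (auto& e : node->inputs) {
      if (!e->is_variable) continue;
      NodePtr target;
      auto it = replace_map.find(e.get());
      if (it == replace_map.end()) {
        if (arg_counter < args.size()) replace_map[e.get()] = target = args[arg_counter];
        ++arg_counter;
      } else {
        target = it->second;  // same variable seen before: reuse its argument
      }
      replace_plan.emplace_back(&e, target);
    }
  });
  assert(arg_counter == args.size() && "incorrect number of arguments");
  for (auto& kv : replace_plan) *kv.first = kv.second;  // apply after the walk
}
```

Deferring the writes until the traversal is finished means the walk never follows an edge that has already been rewired, which is the same reason the real code collects replace_plan first.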
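DFSVisit is the main way to visit every node of the network. FVisit is the operation actually applied to each node, and PostOrderDFSVisit is the real implementation of depth-first search over the directed graph: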
template <typename GNode, typename HashType, typename FVisit,
typename HashFunc, typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads, FVisit fvisit,
HashFunc hash, InDegree indegree, GetInput getinput) {
std::vector<std::pair<GNode, uint32_t> > stack;
std::unordered_set<HashType> visited;
for (auto& head : heads) {
HashType head_hash = hash(head);
if (visited.count(head_hash) == 0) {
stack.push_back(std::make_pair(head, 0));
visited.insert(head_hash);
}
while (!stack.empty()) {
std::pair<GNode, uint32_t>& back = stack.back();
if (back.second == indegree(back.first)) {
fvisit(back.first);
stack.pop_back();
} else {
const GNode& input = getinput(back.first, back.second++);
HashType input_hash = hash(input);
if (visited.count(input_hash) == 0) {
stack.push_back(std::make_pair(input, 0));
visited.insert(input_hash);
}
}
}
}
}
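This is a post-order depth-first traversal that starts from the output (head) nodes and walks back toward the inputs, implemented iteratively with an explicit stack rather than by recursion (unlike the classic recursive DFS that searches forward from a start node). A node's inputs play the role of its predecessors in the graph, and fvisit is called on a node only after all of its inputs have been visited. For a backward node, backward_source_node is treated as one extra, last input (which is why InDegree adds is_backward()), so it is also pushed onto the stack and searched.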
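The template above does not depend on Symbol, so it can be driven on a plain integer graph to see the visit order. In the sketch below the template is copied so the example compiles on its own; PostOrder is a hypothetical helper, and the graph is given as `inputs[n]`, the list of node n's predecessors.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>
#include <utility>
#include <vector>

// Same algorithm as PostOrderDFSVisit above, copied so the example is self-contained.
template <typename GNode, typename HashType, typename FVisit,
          typename HashFunc, typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads, FVisit fvisit,
                       HashFunc hash, InDegree indegree, GetInput getinput) {
  std::vector<std::pair<GNode, uint32_t>> stack;
  std::unordered_set<HashType> visited;
  for (auto& head : heads) {
    HashType head_hash = hash(head);
    if (visited.count(head_hash) == 0) {
      stack.push_back(std::make_pair(head, 0));
      visited.insert(head_hash);
    }
    while (!stack.empty()) {
      std::pair<GNode, uint32_t>& back = stack.back();
      if (back.second == indegree(back.first)) {  // all inputs done: visit node
        fvisit(back.first);
        stack.pop_back();
      } else {                                    // descend into the next input
        const GNode& input = getinput(back.first, back.second++);
        HashType input_hash = hash(input);
        if (visited.count(input_hash) == 0) {
          stack.push_back(std::make_pair(input, 0));
          visited.insert(input_hash);
        }
      }
    }
  }
}

// Collect the post-order of an integer graph where inputs[n] lists n's predecessors.
std::vector<uint32_t> PostOrder(const std::vector<std::vector<uint32_t>>& inputs,
                                const std::vector<uint32_t>& heads) {
  std::vector<uint32_t> order;
  PostOrderDFSVisit<uint32_t, uint32_t>(
      heads,
      [&](uint32_t n) { order.push_back(n); },                              // FVisit
      [](uint32_t n) { return n; },                                          // HashFunc
      [&](uint32_t n) { return static_cast<uint32_t>(inputs[n].size()); },   // InDegree
      [&](uint32_t n, uint32_t i) { return inputs[n][i]; });                 // GetInput
  return order;
}
```

For the diamond graph `{{}, {0}, {0}, {1, 2}}` with head 3, node 0 is visited once and first, and the head comes last: 0, 1, 2, 3.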
StaticGraph
class StaticGraph {
public:
struct DataEntry {
/*! \brief the source node id in the computation graph */
uint32_t source_id;
/*! \brief index of output from the source. */
uint32_t index;
};
/*!
* \brief Operation Node in static graphs.
*
* The reason we explicit support Backward node is to allow special treatment
* such as shape inference and state sharing with Forward pass.
*/
struct Node {
std::unique_ptr<OperatorProperty> op;
std::string name;
std::vector<DataEntry> inputs;
int32_t backward_source_id;
std::map<std::string, std::string> attr;
/*!
*
* Let n = inputs.size() - addto_index_.size();
* the output of the node is defined as:
* - out[j] = op(input[0:n]) for j not in addto_index_
* - out[addto_index_[i]] = op(input[0:n]) + inputs[n + i]
*/
std::vector<uint32_t> addto_index;
......
......
};
/*! \brief all nodes in the graph */
std::vector<Node> nodes;
/*! \brief index of nodes that corresponds to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief heads outputs of the graph */
std::vector<DataEntry> heads;
void Save(dmlc::JSONWriter *writer) const;
void Load(dmlc::JSONReader *reader);
std::vector<uint32_t> TopoSort() const;
std::vector<uint32_t> PostDFSOrder(const std::vector<uint32_t>& head_nodes) const;
void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
std::vector<DataEntry> *arg_grads,
std::map<uint32_t, uint32_t>* out_mirror_map);
......
};
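StaticGraph redefines DataEntry and Node, replacing every Node pointer with an integer index into nodes. Its Node also gains addto_index, which lets an output be accumulated onto an existing input buffer and thus saves memory during gradient computation.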
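The addto rule in the comment of StaticGraph::Node can be spelled out on scalars. RunWithAddto below is a hypothetical helper, not MXNet code: the first n = inputs.size() - addto_index.size() inputs feed the operator, and each trailing input is added onto the output named by the matching addto_index entry.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Apply the addto rule:
//   out[j]              = op(inputs[0:n])                   for j not in addto_index
//   out[addto_index[i]] = op(inputs[0:n]) + inputs[n + i]
std::vector<double> RunWithAddto(
    const std::function<std::vector<double>(const std::vector<double>&)>& op,
    const std::vector<double>& inputs,
    const std::vector<uint32_t>& addto_index) {
  assert(inputs.size() >= addto_index.size());
  size_t n = inputs.size() - addto_index.size();
  std::vector<double> real_inputs(inputs.begin(), inputs.begin() + n);
  std::vector<double> out = op(real_inputs);  // plain outputs
  for (size_t i = 0; i < addto_index.size(); ++i) {
    out[addto_index[i]] += inputs[n + i];     // accumulate onto an existing value
  }
  return out;
}
```

The memory saving presumably comes from writing a gradient straight onto an already accumulated value instead of materializing it separately and adding it in another node.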
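StaticGraph stores the graph in a one-dimensional array (nodes) whose elements follow a topological order of the symbol. TopoSort obtains this order by taking the nodes that no other node uses as input (out-degree 0, i.e. the graph outputs) as heads and running a post-order depth-first traversal (PostDFSOrder) from them.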
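A sketch of that idea on an index-based graph, assuming heads are the nodes that no other node consumes. TopoSortToy is a hypothetical name, and backward_source_id edges are left out for simplicity. Because a post-order DFS emits a node only after all its inputs, the result is a valid topological order.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// inputs[i] lists the source ids that node i reads from.
std::vector<uint32_t> TopoSortToy(const std::vector<std::vector<uint32_t>>& inputs) {
  const size_t n = inputs.size();
  std::vector<int> out_degree(n, 0);  // how many nodes consume each node
  for (const auto& in : inputs)
    for (uint32_t src : in) ++out_degree[src];
  std::vector<uint32_t> order;
  std::vector<bool> visited(n, false);
  std::function<void(uint32_t)> dfs = [&](uint32_t v) {
    if (visited[v]) return;
    visited[v] = true;
    for (uint32_t src : inputs[v]) dfs(src);  // inputs first
    order.push_back(v);                       // then the node itself
  };
  for (uint32_t i = 0; i < n; ++i)
    if (out_degree[i] == 0) dfs(i);           // start from unconsumed nodes
  return order;
}
```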
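MakeBackwardPass builds the backward pass. It calls TopoSort at the start; the body then fills in the three output arguments passed to it and appends new Node objects to the nodes member with push_back, in this order:
1. If mirroring is used, copy each node that needs a mirror and append the copy to nodes.
2. Create a new Node for each entry of heads and append it as a gradient head node.
3. If an output of a node is consumed by more than one op, so that several partial gradients flow back into it, call CreateGradSumNode to create a node that sums them and append it.
4. Create the backward (gradient) node and append it.
5. If an argument's gradient comes from several nodes (the argument is an input to more than one node), call CreateGradSumNode to create a summing node and append it.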
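The gradient-sum steps can be sketched as follows. CreateGradSumNode's real signature is not shown in this post, so SumGradients and ToyNode are hypothetical: for every entry, the caller lists the nodes that each produce one partial gradient; a single part is used as is, while two or more parts get a new summing node appended to nodes.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::vector<uint32_t> inputs;  // ids of the nodes this node reads
};

// grad_contribs[entry] = ids of nodes producing partial gradients for `entry`.
// Returns, per entry, the id of the node holding its total gradient.
std::map<uint32_t, uint32_t> SumGradients(
    std::vector<ToyNode>* nodes,
    const std::map<uint32_t, std::vector<uint32_t>>& grad_contribs) {
  std::map<uint32_t, uint32_t> total;
  for (const auto& kv : grad_contribs) {
    const std::vector<uint32_t>& parts = kv.second;
    assert(!parts.empty());
    if (parts.size() == 1) {
      total[kv.first] = parts[0];  // single consumer: no sum needed
    } else {
      ToyNode sum;
      sum.name = (*nodes)[kv.first].name + "_grad_sum";
      sum.inputs = parts;          // the sum node adds all partial gradients
      nodes->push_back(sum);
      total[kv.first] = static_cast<uint32_t>(nodes->size() - 1);
    }
  }
  return total;
}
```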
GraphExecutor
After the graph has been composed with Symbol, it can be converted into a StaticGraph, which GraphExecutor then executes. The GraphExecutor class looks like this:
class GraphExecutor : public Executor {
public:
GraphExecutor() {}
virtual ~GraphExecutor();
void Forward(bool is_train) override;
void PartialForward(bool is_train, int step, int *step_left) override;
void Backward(const std::vector<NDArray> &head_grads) override;
const std::vector<NDArray> &outputs() const override {
return heads_ndarray_;
}
......
// Called from Executor::Bind, exactly once.
inline void Init(Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray> &in_args,
const std::vector<NDArray> &arg_grad_store,
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states,
Executor* shared_exec = nullptr) {
enable_inplace_allocation_ = dmlc::GetEnv("MXNET_EXEC_ENABLE_INPLACE", true);
prefer_bulk_execution_ = dmlc::GetEnv("MXNET_EXEC_PREFER_BULK_EXEC", true);
if (shared_exec != NULL) {
GraphExecutor* gexec = dynamic_cast<GraphExecutor*>(shared_exec);
CHECK(gexec) << "Input executor for sharing memory must have GraphExecutor type.";
shared_mem_ = gexec->shared_mem_;
} else {
shared_mem_ = std::make_shared<GraphStoragePool>();
}
CHECK_EQ(grad_req_type.size(), arg_grad_store.size());
bool need_backward = false;
for (auto req : grad_req_type) {
if (req != kNullOp) need_backward = true;
}
this->InitGraph(symbol, default_ctx, ctx_map,
in_args, arg_grad_store, grad_req_type,
need_backward);
this->InitDataEntryInfo(in_args, arg_grad_store, grad_req_type, aux_states);
this->InitOperators();
this->InitDataEntryMemory();
this->InitResources();
this->InitCachedOps();
}
protected:
// internal class of wrapping BackwardOp as ForwardOp
class BackwardOpWrapper;
// type of data entry
enum DataEntryType {
// memory is bound by external NDArray in Bind
kBindByExternal,
// to be bound by external NDArray in Forward and Backward
kTobeBindByExternal,
// internal memory, allocated
kInternalAllocated,
// internal memory, to be allocated
kNotInitialized
};
// Additional information about each data entry
struct DataEntryInfo {
// the actual data for the entry
NDArray data;
// write request to this entry
OpReqType op_req;
// the operation node that will take
// this DataEntry as inplace input
int inplace_op_id;
// data entry type
DataEntryType type;
// shape of this entry
TShape shape;
// data type of this entry
int type_flag;
// storage id from allocator if it is internal allocation.
GraphStorageAllocator::StorageID storage_id;
// reference count on how many times this entry is being used.
// That is how many operators and heads need this DataEntry
// this is a temporal variable that is used during initialization.
uint32_t temp_ref_count;
// real permanent ref count
uint32_t ref_count;
// constructor
DataEntryInfo()
: op_req(kNullOp),
inplace_op_id(-1),
type(kNotInitialized),
storage_id(GraphStorageAllocator::kBadStorageID),
temp_ref_count(0), ref_count(0) {}
};
// all the information needed to push the op to engine
struct OpExecEntry {
// execution function for
Engine::AsyncFn exec_fun;
// variables to read from
std::vector<Engine::VarHandle> use_vars;
// variables to mutate
std::vector<Engine::VarHandle> mutate_vars;
// constructor
OpExecEntry() : exec_fun(nullptr) {}
};
// Information about operational node
struct OpNode {
// whether this op node is activated
bool activated;
// the context of the node
Context ctx;
// data entry information about outputs of op
std::vector<DataEntryInfo> outputs;
// auxiliary data information of op
std::vector<DataEntryInfo> aux_states;
// The following parts are constructed in InitOpNodes
// the real operator
std::shared_ptr<Operator> op;
// op context, that is defined for this op.
OpContext op_ctx;
// executor, this is only allocated for nodes
// whose inputs, outputs are pre-defined.
// otherwise cached_exec.exec_fun == nullptr
OpExecEntry cached_exec;
// cached operator handle
Engine::OprHandle cached_opr{nullptr};
// constructor
OpNode() : activated(false) {}
// Manual option for delete operator
// need to do this before delete NDArrays
inline void DeleteOperator() {
if (cached_opr != nullptr) {
Engine::Get()->DeleteOperator(cached_opr);
cached_opr = nullptr;
}
}
};
......
......
inline OpExecEntry GetOpExecEntry(uint32_t node_id);
// run ops from topo order start to end
void RunOps(bool is_train, size_t topo_start, size_t topo_end);
// internal computational graph
StaticGraph graph_;
// topological order of nodes in computation graph
// backward nodes always follow forward nodes
std::vector<uint32_t> topo_order_;
// whether to enable inplace space
bool enable_inplace_allocation_;
// total allocated space in bytes
size_t total_allocated_bytes_;
// total allocated temp space
size_t total_allocated_temp_;
// number of forward nodes in the graph
size_t num_forward_nodes_;
// whether to enable bulk execution
bool prefer_bulk_execution_;
// head gradient node in the graph, if there is backward pass
std::vector<uint32_t> head_grad_nodes_;
// mirror map of nodes, experimental feature, normally can be ignored.
std::map<uint32_t, uint32_t> mirror_source_map_;
// argument node in the graph, if there is backward pass
std::vector<StaticGraph::DataEntry> arg_grads_;
// operator nodes, in one-to-one correspondence with graph_.nodes
std::vector<OpNode> op_nodes_;
// head NDArrays
std::vector<NDArray> heads_ndarray_;
// shared NDArrays
std::shared_ptr<GraphStoragePool> shared_mem_;
// monitor call back
std::function<void(const char*, void*)> monitor_callback_;
// cached segment operator
std::unordered_map<size_t, Engine::OprHandle> cached_seg_opr_;
}; // class GraphExecutor
In GraphExecutor, Forward and Backward carry out the forward and backward computation of the graph. Both rely on one core function, RunOps:
void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
// heuristic, only enable bulk on forward only
bool bulk_exec = prefer_bulk_execution_ && !monitor_callback_
&& topo_start == 0 && num_forward_nodes_ == topo_order_.size();
// bulk execution: push the whole segment as one cached operator
if (bulk_exec) {
// encode things into a key
size_t key = topo_start * op_nodes_.size() + topo_end;
if (cached_seg_opr_.count(key) == 0) {
cached_seg_opr_[key] = this->CreateCachedOpr(topo_start, topo_end);
if (cached_seg_opr_.at(key) != nullptr) {
LOG(INFO) << "Created bulk execution on segment ["
<< topo_start << ", " << topo_end << ")";
}
}
auto cached_op = cached_seg_opr_.at(key);
if (cached_op != nullptr) {
Context* pctx = nullptr;
for (size_t i = topo_start; i < topo_end; ++i) {
uint32_t nid = topo_order_[i];
if (!op_nodes_[nid].activated) continue;
if (graph_.nodes[nid].is_variable()) continue;
OpNode& opnode = op_nodes_[nid];
opnode.op_ctx.is_train = is_train;
pctx = &(opnode.ctx);
}
Engine::Get()->Push(cached_op, *pctx);
return;
}
}
for (size_t i = topo_start; i < topo_end; ++i) {
uint32_t nid = topo_order_[i];
if (!op_nodes_[nid].activated) continue;
if (graph_.nodes[nid].is_variable()) continue;
OpNode& opnode = op_nodes_[nid];
// special handle cross device copy op
if (opnode.op->exec_type() == Operator::kCrossDeviceCopy) {
CHECK_EQ(graph_.nodes[nid].inputs.size(), 1);
CHECK_EQ(opnode.outputs.size(), 1);
auto in = graph_.nodes[nid].inputs[0];
CopyFromTo(op_nodes_[in.source_id].outputs[in.index].data,
&(opnode.outputs[0].data));
continue;
}
opnode.op_ctx.is_train = is_train;
if (opnode.cached_opr != nullptr) {
Engine::Get()->Push(opnode.cached_opr, opnode.ctx);
} else {
auto exec = GetOpExecEntry(nid);
Engine::Get()->PushAsync(
exec.exec_fun,
opnode.ctx,
exec.use_vars,
exec.mutate_vars,
FnProperty::kNormal);
}
if (monitor_callback_) {
std::vector<std::string> output_names;
if (graph_.nodes[nid].is_forward()) {
output_names = graph_.nodes[nid].op->ListOutputs();
} else {
int source_id = graph_.nodes[nid].backward_source_id;
output_names = graph_.nodes[source_id].op->ListArguments();
}
for (index_t i = 0; i < opnode.outputs.size(); ++i) {
NDArray out_data = opnode.outputs[i].data;
std::string name = graph_.nodes[nid].name + "_" + output_names[i];
NDArray *cpy = new NDArray(out_data);
this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
}
}
}
}
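When RunOps runs in bulk mode, it creates (and caches) a single operator for the whole segment and pushes it to the execution engine; by the heuristic at the top of the function, this only happens when the graph is forward-only, the segment starts at 0 and no monitor callback is installed. Otherwise it pushes operations to the engine one node at a time, using a node's cached_opr when one exists and GetOpExecEntry otherwise; cross-device copy nodes are handled separately with CopyFromTo. GetOpExecEntry builds the execution function of an operator node so that it can be pushed to the engine, and its core step is a call to Forward. For a forward node, this is simply the Forward of the corresponding Operator (such as Convolution). For a backward node, the corresponding Operator's Backward is wrapped in the Forward function of BackwardOpWrapper, so that GetOpExecEntry can invoke every node the same way.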
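The wrapping trick can be sketched with a toy operator interface. ToyOperator, Scale3, BackwardWrapper and RunNode are hypothetical and much simpler than MXNet's Operator (no TBlobs, OpContext or request types); the point is only that a wrapper whose Forward calls the wrapped operator's Backward lets the executor call Forward on every node.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Minimal operator interface with separate Forward and Backward.
struct ToyOperator {
  virtual ~ToyOperator() = default;
  virtual void Forward(const std::vector<float>& in, std::vector<float>* out) = 0;
  virtual void Backward(const std::vector<float>& out_grad, std::vector<float>* in_grad) {
    (void)out_grad;
    (void)in_grad;
  }
};

// y = 3 * x, so dy/dx = 3.
struct Scale3 : ToyOperator {
  void Forward(const std::vector<float>& in, std::vector<float>* out) override {
    out->assign(1, 3.0f * in[0]);
  }
  void Backward(const std::vector<float>& out_grad, std::vector<float>* in_grad) override {
    in_grad->assign(1, 3.0f * out_grad[0]);
  }
};

// In the spirit of BackwardOpWrapper: its Forward runs the wrapped op's Backward.
struct BackwardWrapper : ToyOperator {
  explicit BackwardWrapper(std::shared_ptr<ToyOperator> fwd) : fwd_(std::move(fwd)) {}
  void Forward(const std::vector<float>& in, std::vector<float>* out) override {
    fwd_->Backward(in, out);
  }
  std::shared_ptr<ToyOperator> fwd_;  // shares the forward op (and any state)
};

// Executor-side call: identical for forward and backward nodes.
void RunNode(ToyOperator* op, const std::vector<float>& in, std::vector<float>* out) {
  op->Forward(in, out);
}
```

Sharing the forward operator object between the two nodes also lets the backward pass reuse any state the forward pass kept.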
The other group of GraphExecutor functions are the initialization functions called from Init, which set up the graph for execution.
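InitGraph consists of three main steps:
1. Call MakeBackwardPass to build the backward pass.
2. Call AssignContext to assign a device context to every node, storing the result in op_nodes_. If an input of a node lives on a different device from the node itself, a copy node is created and appended to the end of nodes.
3. Topologically sort the nodes produced by the two steps above, in a specific way that guarantees backward nodes come after forward nodes, and store the result in topo_order_.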
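One way to get an order in which backward nodes always follow forward nodes is sketched below; it is an assumption about the approach, not a copy of InitGraph. The forward nodes (everything reachable from the forward heads) are emitted first in post-DFS order, then every remaining node is appended in the order of a full post-DFS from all heads. Since the forward set is closed under taking inputs, the combined order is still topological. PostDFS and ForwardThenBackward are hypothetical names, and backward_source edges are folded into inputs.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <unordered_set>
#include <vector>

// Post-order DFS from `heads`; inputs[i] lists the ids node i reads.
std::vector<uint32_t> PostDFS(const std::vector<std::vector<uint32_t>>& inputs,
                              const std::vector<uint32_t>& heads) {
  std::vector<uint32_t> order;
  std::vector<bool> seen(inputs.size(), false);
  std::function<void(uint32_t)> dfs = [&](uint32_t v) {
    if (seen[v]) return;
    seen[v] = true;
    for (uint32_t u : inputs[v]) dfs(u);
    order.push_back(v);
  };
  for (uint32_t h : heads) dfs(h);
  return order;
}

// Forward nodes first, then the remaining (backward) nodes in topological order.
std::vector<uint32_t> ForwardThenBackward(const std::vector<std::vector<uint32_t>>& inputs,
                                          const std::vector<uint32_t>& forward_heads,
                                          const std::vector<uint32_t>& all_heads) {
  std::vector<uint32_t> order = PostDFS(inputs, forward_heads);
  std::unordered_set<uint32_t> fwd(order.begin(), order.end());
  for (uint32_t v : PostDFS(inputs, all_heads))
    if (!fwd.count(v)) order.push_back(v);
  return order;
}
```

The size of the first part is what a field like num_forward_nodes_ would record.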
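InitDataEntryInfo infers shapes and types from the arguments passed in, and fills in op_nodes_: fields such as data, type, type_flag and ref_count of each DataEntryInfo, and the activated flag of each OpNode.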
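ref_count is, per the comment in DataEntryInfo, how many operators and heads need an entry. A hypothetical sketch of that count over an index-based graph, where an entry is a (source_id, index) pair as in StaticGraph::DataEntry:

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

struct Entry {  // (source node id, output index), like StaticGraph::DataEntry
  uint32_t source_id;
  uint32_t index;
  bool operator<(const Entry& o) const {
    return source_id != o.source_id ? source_id < o.source_id : index < o.index;
  }
};

// inputs[i] lists the entries read by node i; heads are the graph outputs.
std::map<Entry, uint32_t> CountRefs(const std::vector<std::vector<Entry>>& inputs,
                                    const std::vector<Entry>& heads) {
  std::map<Entry, uint32_t> ref;
  for (const auto& node_inputs : inputs)
    for (const Entry& e : node_inputs) ++ref[e];  // one use per consuming input
  for (const Entry& e : heads) ++ref[e];          // heads must stay alive too
  return ref;
}
```

An entry whose count drops to zero during planning is no longer needed, which is what makes memory reuse possible later.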
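InitOperators initializes every OpNode that InitDataEntryInfo has activated and whose Node is not an argument (variable) node. For a forward node it calls CreateOperatorEx to create the OpNode's Operator; for a backward node it constructs the Operator with the BackwardOpWrapper constructor.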
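InitDataEntryMemory allocates memory for GraphExecutor through the GraphStorageAllocator class, which works in two steps. First, GraphExecutor calls request and release following the dependency plan, declaring which storage will be requested and released; each request returns a StorageID that identifies the memory block assigned to a DataEntryInfo. Second, InitStorages allocates the actual memory.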
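The two-step scheme can be illustrated with a toy allocator. ToyStorageAllocator is hypothetical and simplified (blocks are matched by exact byte size only; the real allocator also deals with contexts and shapes): during planning, Request and Release only hand out and recycle integer IDs, and real buffers exist only after InitStorages.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <vector>

class ToyStorageAllocator {
 public:
  using StorageID = int;

  // Planning step: reuse a released block of the same size, or plan a new one.
  StorageID Request(size_t bytes) {
    auto it = free_.find(bytes);
    if (it != free_.end() && !it->second.empty()) {
      StorageID id = it->second.back();
      it->second.pop_back();
      return id;
    }
    sizes_.push_back(bytes);
    return static_cast<StorageID>(sizes_.size() - 1);
  }

  // Planning step: the block may be handed to a later request.
  void Release(StorageID id) { free_[sizes_[id]].push_back(id); }

  // Second step: the only place where real memory is allocated.
  void InitStorages() {
    for (size_t bytes : sizes_) buffers_.push_back(std::unique_ptr<char[]>(new char[bytes]));
  }

  size_t num_blocks() const { return sizes_.size(); }

 private:
  std::vector<size_t> sizes_;                      // size of each planned block
  std::map<size_t, std::vector<StorageID>> free_;  // released IDs, by size
  std::vector<std::unique_ptr<char[]>> buffers_;   // real memory after InitStorages
};
```

For a chain a, b, c where b reads a and c reads b, the plan is request(a), request(b), release(a), request(c); c then gets a's block, so three entries need only two buffers.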
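InitDataEntryMemory first checks the state of the nodes' inputs and of the op_nodes_ outputs, which is stored in DataEntryInfo (the type and temp_ref_count fields): node inputs (type and temp_ref_count) must already be initialized, and op_nodes_ outputs (type) must not have been allocated internally yet. Once these checks pass, and before any memory is allocated, it calls GetInplaceOption to find the DataEntryInfo entries that can be computed in place and updates the corresponding input and output states. Memory is then allocated with the two-step GraphStorageAllocator scheme. Finally, the NDArrays of the graph's heads are pushed into heads_ndarray_.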
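InitResources allocates temporary memory; an operation such as cudnnConvolutionForward, for example, needs a workspace to support its computation. InitResources obtains each concrete ResourceRequest through GetResource, and these ResourceRequests are in turn generated by the specific operator through ForwardResource or BackwardResource. MXNet colors the graph so that nodes which may run in parallel get different colors, while nodes of the same color lie on one execution path and never run concurrently, so they can share temporary space. When temp space is requested, the requests of same-colored nodes are therefore pushed into the resource field of the OpContext of their OpNodes, and the first request from each color creates a new Resource through the ResourceManager class. The coloring is produced by ColorNodeGroup with up to max_ncolor colors: ColorNodeGroup runs max_ncolor iterations, and each iteration finds one complete execution path for one color.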
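InitCachedOps fills in the cached_opr field of an OpNode, provided that none of the OpNode's outputs and none of the inputs of the corresponding Node has DataEntryType kTobeBindByExternal (that is, their memory is already fixed rather than bound later in Forward or Backward), and that the OpNode's Operator is not a cross-device copy.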
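The caching condition reduces to a small predicate. CanCacheOp is a hypothetical helper; only the DataEntryType values are taken from the GraphExecutor listing above.

```cpp
#include <cassert>
#include <vector>

// DataEntryType values as listed in GraphExecutor.
enum DataEntryType { kBindByExternal, kTobeBindByExternal, kInternalAllocated, kNotInitialized };

// An op can be pre-built (cached) only if none of its inputs/outputs is bound
// late (at Forward/Backward time) and it is not a cross-device copy.
bool CanCacheOp(const std::vector<DataEntryType>& input_types,
                const std::vector<DataEntryType>& output_types,
                bool is_cross_device_copy) {
  if (is_cross_device_copy) return false;
  for (DataEntryType t : input_types)
    if (t == kTobeBindByExternal) return false;
  for (DataEntryType t : output_types)
    if (t == kTobeBindByExternal) return false;
  return true;
}
```

A cached operator fixes its read and write variables once, so any entry whose NDArray may still be swapped later rules caching out.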