1 概述
TensorFlow后端分为四层,运行时层、计算层、通信层、设备层。运行时作为第一层,实现了session管理、graph管理等很多重要的逻辑,是十分关键的一层。根据任务分布的不同,运行时又分为本地运行时和分布式运行时。本地运行时,所有任务运行于本地同一进程内。而分布式运行时,则允许任务运行在不同机器上。
Tensorflow的运行,通过session搭建了前后端沟通的桥梁,前端几乎所有操作都是通过session进行。session的生命周期由创建、运行、关闭、销毁组成,前文已经详细讲述过。可以将session看做TensorFlow运行的载体。而TensorFlow运行的核心对象,则是计算图Graph。它由计算算子和计算数据两部分构成,可以完整描述整个计算内容。Graph的生命周期包括构建和传递、剪枝、分裂、执行等步骤,本文会详细讲解。理解TensorFlow的运行时,重点就是理解会话session和计算图Graph。
本地运行时,client master和worker都在本地机器的同一进程内,均通过DirectSession类来描述。由于在同一进程内,三者间可以共享内存,通过DirectSession的相关函数实现调用。
client前端直接面向用户,负责session的创建,计算图Graph的构造。并通过session.run()将Graph序列化后传递给master。master收到后,先反序列化得到Graph,然后根据反向依赖关系,得到几个最小依赖子图,这一步称为剪枝。之后master根据可运行的设备情况,将子图分裂到不同设备上,从而可以并发执行,这一步称为分裂。最后,由每个设备上的worker并行执行分裂后的子图,得到计算结果后返回。
2 Graph构建和传递
session.run()开启了后端Graph的构建和传递。在前文session生命周期的讲解中,session.run()时会先调用_extend_graph()将要运行的Operation添加到Graph中,然后再启动运行过程。extend_graph()会先将graph序列化,得到graph_def,然后调用后端的TF_ExtendGraph()方法。下面我们从 http://c_api.cc 中的TF_ExtendGraph()看起。
// 增加节点到graph中,proto为序列化后的graph
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
size_t proto_len, TF_Status* status) {
GraphDef g;
// 先将proto转换为GrapDef。graphDef是图的序列化表示,反序列化在后面。
if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
// 再调用session的extend方法。根据创建的不同session类型,多态调用不同方法。
status->status = s->session->Extend(g);
}
后端系统根据生成的Session类型,多态的调用Extend方法。如果是本地session,则调用DirectSession的Extend()方法。下面看DirectSession的Extend()方法。
Status DirectSession::Extend(const GraphDef& graph) {
// 保证线程安全,然后调用ExtendLocked()
mutex_lock l(graph_def_lock_);
return ExtendLocked(graph);
}
// 主要任务就是创建GraphExecutionState对象。
Status DirectSession::ExtendLocked(const GraphDef& graph) {
bool already_initialized;
if (already_initialized) {
TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
// 创建GraphExecutionState
std::unique_ptr<GraphExecutionState> state;
TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
execution_state_.swap(state);
}
return Status::OK();
}
最终创建了GraphExecutionState对象。它主要工作有
- 负责将GraphDef反序列化为graph,从而构造出graph。在初始化方法InitBaseGraph()中
- 执行部分op编排工作,在初始化方法InitBaseGraph()中
Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
const GraphDef* graph_def = &original_graph_def_;
// graphDef反序列化得到graph
std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));
// 恢复有状态的节点
RestoreStatefulNodes(new_graph.get());
// 构造优化器的选项 optimization_options
GraphOptimizationPassOptions optimization_options;
optimization_options.session_options = session_options_;
optimization_options.graph = &new_graph;
optimization_options.flib_def = flib_def_.get();
optimization_options.device_set = device_set_;
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
// plaer执行op编排
Placer placer(new_graph.get(), device_set_, session_option