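This section introduces the graph in TensorFlow. c_api.cc contains an example of creating a graph, which makes a good entry point for exploring how the graph is used.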
TF_Graph
In c_api.cc, the code that creates a graph is:
TF_Graph* TF_NewGraph() { return new TF_Graph; }

TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {
  // Tell the shape refiner to also run shape inference on functions.
  refiner.set_function_library_for_shape_inference(&graph.flib_def());
}
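The initializer list above reads `graph.versions().producer()` and `graph.op_registry()` while constructing `refiner`. That is only safe because `graph` is declared before `refiner` in the TF_Graph struct: C++ initializes members in declaration order, not in the order they appear in the initializer list. A minimal sketch of the same pattern, using hypothetical stand-in types (`MiniGraph`, `MiniRefiner`, `MiniTFGraph` are not TensorFlow classes):

```cpp
#include <cassert>

// Hypothetical stand-in for tensorflow::Graph: only carries a producer version.
struct MiniGraph {
  int producer_version;
  explicit MiniGraph(int producer) : producer_version(producer) {}
  int producer() const { return producer_version; }
};

// Hypothetical stand-in for tensorflow::ShapeRefiner.
struct MiniRefiner {
  int graph_def_version;
  explicit MiniRefiner(int version) : graph_def_version(version) {}
};

struct MiniTFGraph {
  // `graph` is declared before `refiner`, so it is fully constructed
  // by the time refiner's initializer calls graph.producer().
  MiniGraph graph;
  MiniRefiner refiner;
  bool delete_requested;
  MiniTFGraph* parent;

  explicit MiniTFGraph(int producer)
      : graph(producer),
        refiner(graph.producer()),
        delete_requested(false),
        parent(nullptr) {}
};

// Same shape as TF_NewGraph: heap-allocate and hand back a raw pointer.
MiniTFGraph* NewMiniGraph(int producer) { return new MiniTFGraph(producer); }
```

If the two member declarations were swapped, `refiner` would be built from a not-yet-constructed `graph`, which is undefined behavior even though the initializer list text looks identical.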
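TF_NewGraph creates a graph simply by allocating a new TF_Graph, so the real work happens in the TF_Graph constructor. TF_Graph is defined in tensorflow/c/c_api_internal.h: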
struct TF_Graph {
  TF_Graph();
  mutable tensorflow::mutex mu;
  tensorflow::Graph graph TF_GUARDED_BY(mu);
  // Runs shape inference.
  tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
  // Maps from name of an operation to the Node* in 'graph'.
  std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
      TF_GUARDED_BY(mu);
  // The keys of this map are all the active sessions using this graph. Each
  // value records whether the graph has been mutated since the corresponding
  // session has been run (this is detected in RecordMutation function). If the
  // string is empty, no mutation has occurred. Otherwise the string is a
  // description of the mutation suitable for returning to the user.
  //
  // Sessions are added to this map in TF_NewSession, and removed in
  // TF_DeleteSession.
  // TF_Graph may only / must be deleted when
  //   sessions.size() == 0 && delete_requested == true
  //
  // TODO(b/74949947): mutations currently trigger a warning instead of a bad
  // status, this should be reverted when possible.
  tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
      TF_GUARDED_BY(mu);
  bool delete_requested TF_GUARDED_BY(mu);  // set true by TF_DeleteGraph
  // Used to link graphs contained in TF_WhileParams to the parent graph that
  // will eventually contain the full while loop.
  TF_Graph* parent;
  TF_Output* parent_inputs;
};
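The comment on `sessions` states the lifetime rule: a TF_Graph may only be freed once TF_DeleteGraph has requested deletion and no session still refers to it. The sketch below models that rule with hypothetical types (`GraphHandle`, `SessionStub`, `AttachSession`, `DetachSession`, `RequestDelete` are illustrative names, not the C API); it assumes both the delete-graph and delete-session paths re-check the same condition under the mutex.

```cpp
#include <cassert>
#include <mutex>
#include <string>
#include <unordered_map>

struct SessionStub {};  // hypothetical stand-in for TF_Session

struct GraphHandle {
  std::mutex mu;
  // Session -> description of mutations since that session last ran
  // (an empty string means "no mutation").
  std::unordered_map<SessionStub*, std::string> sessions;
  bool delete_requested = false;
};

// The documented condition. Caller must hold g->mu.
bool ShouldDelete(const GraphHandle* g) {
  return g->sessions.empty() && g->delete_requested;
}

// Analogue of TF_NewSession registering itself with the graph.
void AttachSession(GraphHandle* g, SessionStub* s) {
  std::lock_guard<std::mutex> lock(g->mu);
  g->sessions.emplace(s, "");
}

// Analogue of TF_DeleteSession: detach, then report whether the graph
// should now be freed.
bool DetachSession(GraphHandle* g, SessionStub* s) {
  std::lock_guard<std::mutex> lock(g->mu);
  g->sessions.erase(s);
  return ShouldDelete(g);
}

// Analogue of TF_DeleteGraph: request deletion, then check the rule.
bool RequestDelete(GraphHandle* g) {
  std::lock_guard<std::mutex> lock(g->mu);
  g->delete_requested = true;
  return ShouldDelete(g);
}
```

Whichever side happens last (deleting the graph or deleting its last session) is the one that sees the condition become true and frees the graph.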
TF_Graph is a struct whose core members are:
tensorflow::Graph graph
tensorflow::ShapeRefiner refiner
graph
First, let's look at how graph is initialized. tensorflow::Graph is a class; it is declared in tensorflow/core/graph/graph.h and implemented in tensorflow/core/graph/graph.cc.
Its main members and methods include:
class Graph {
 public:
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in the registry. `ops`s lifetime must be at
  // least that of the constructed graph's.
  explicit Graph(const OpRegistryInterface* ops);
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `flib_def`. Unlike the constructor taking
  // an OpRegistryInterface, this constructor copies the function definitions in
  // `flib_def` so its lifetime may be shorter than that of the graph's. The
  // OpRegistryInterface backing `flib_def` must still have the lifetime of the
  // graph though.
  explicit Graph(const FunctionLibraryDefinition& flib_def);
  ~Graph();
  // Clone the current graph into a new one.
  std::unique_ptr<Graph> Clone();
  static const int kControlSlot;
  // The GraphDef version range of this graph (see graph.proto).
  const VersionDef& versions() const;
  void set_versions(const VersionDef& versions);
  // Adds a new node to this graph, and returns it. Infers the Op and
  // input/output types for the node. *this owns the returned instance.
  // Returns nullptr and sets *status on error.
  Node* AddNode(NodeDef node_def, Status* status);
  // Same as above, but using StatusOr. This method is always preferred.
  StatusOr<Node*> AddNode(NodeDef node_def);
  // Copies *node, which may belong to another graph, to a new node,
  // which is returned. Does not copy any edges. *this owns the
  // returned instance.
  Node* CopyNode(const Node* node);
  // Removes a node from this graph, including all edges from or to it.
  // *node should not be accessed after calling this function.
  // REQUIRES: node->IsOp()
  void RemoveNode(Node* node);
  void Copy(const Graph& src);
  // Removes all nodes from this graph, including all edges from or to them.
  // No Node* references to the Graph are valid post.
  void Clear();
  // Adds an edge that connects the xth output of `source` to the yth input of
  // `dest` and returns it. Does not update dest's NodeDef.
  const Edge* AddEdge(Node* source, int x, Node* dest, int y);
  // Adds a control edge (no data flows along this edge) that connects `source`
  // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
  // adds the control input.
  //
  // If such a control edge already exists and `allow_duplicates` is false, no
  // edge is added and the function returns nullptr. Otherwise the edge is
  // unconditionally created and returned. The NodeDef is not updated if
  // `allow_duplicates` is true.
  // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
  // graph_partition.cc. Figure out if we can do away with it.
  const Edge* AddControlEdge(Node* source, Node* dest,
                             bool allow_duplicates = false);
  // Removes edge from the graph. Does not update the destination node's
  // NodeDef.
  // REQUIRES: The edge must exist.
  void RemoveEdge(const Edge* edge);
  // Removes control edge `edge` from the graph. Note that this also updates
  // the corresponding NodeDef to reflect the change.
  // REQUIRES: The control edge must exist.
  void RemoveControlEdge(const Edge* e);
  // Updates the input to a node. The existing edge to `dst` is removed and an
  // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
  // is also updated.
  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
  // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
  // "While" op during gradient construction, see AddInputWhileHack in
  // python_api.h for more details.
  Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
  // Adds the function and gradient definitions in `fdef_lib` to this graph's op
  // registry. Ignores duplicate functions, and returns a bad status if an
  // imported function differs from an existing function or op with the same
  // name.
  Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
  // The number of live nodes in the graph.
  //
  // Because nodes can be removed from the graph, num_nodes() is often
  // smaller than num_node_ids(). If one needs to create an array of
  // nodes indexed by node ids, num_node_ids() should be used as the
  // array's size.
  int num_nodes() const { return num_nodes_; }
  // The number of live nodes in the graph, excluding the Source and Sink nodes.
  int num_op_nodes() const {
    DCHECK_GE(num_nodes_, 2);
    return num_nodes_ - 2;
  }
  // The number of live edges in the graph.
  //
  // Because edges can be removed from the graph, num_edges() is often
  // smaller than num_edge_ids(). If one needs to create an array of
  // edges indexed by edge ids, num_edge_ids() should be used as the
  // array's size.
  int num_edges() const { return num_edges_; }
  // Serialize the nodes starting at `from_node_id` to a GraphDef.
  void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
  // Serialize to a GraphDef.
  void ToGraphDef(GraphDef* graph_def) const;
  // This version can be called from debugger to inspect the graph content.
  // Use the previous version outside debug context for efficiency reasons.
  //
  // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
  // not defined in some TensorFlow builds.
  GraphDef ToGraphDefDebug() const;
  // Generate new node name with the specified prefix that is unique
  // across this graph.
  std::string NewName(StringPiece prefix);
  // Access to the list of all nodes. Example usage:
  //   for (Node* node : graph.nodes()) { ... }
  gtl::iterator_range<NodeIter> nodes() const;
  // Access to the list of all nodes, excluding the Source and Sink nodes.
  gtl::iterator_range<NodeIter> op_nodes() const;
  // Returns one more than the maximum id assigned to any node.
  int num_node_ids() const { return nodes_.size(); }
  // Returns the node associated with an id, or nullptr if no node
  // with that id (the node with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_node_ids().
  Node* FindNodeId(int id) const { return nodes_[id]; }
  // Returns one more than the maximum id assigned to any edge.
  int num_edge_ids() const { return edges_.size(); }
  // Returns the Edge associated with an id, or nullptr if no edge
  // with that id (the edge with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_edge_ids().
  const Edge* FindEdgeId(int id) const { return edges_[id]; }
  // Access to the set of all edges. Example usage:
  //   for (const Edge* e : graph.edges()) { ... }
  GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
  // The pre-defined nodes.
  enum { kSourceId = 0, kSinkId = 1 };
  Node* source_node() const { return FindNodeId(kSourceId); }
  Node* sink_node() const { return FindNodeId(kSinkId); }
  const OpRegistryInterface* op_registry() const { return &ops_; }
  const FunctionLibraryDefinition& flib_def() const { return ops_; }
  // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
  FunctionLibraryDefinition* mutable_flib_def() { return &ops_; }
  void CheckDeviceNameIndex(int index) {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, static_cast<int>(device_names_.size()));
  }
  int InternDeviceName(const std::string& device_name);
  const std::string& get_assigned_device_name(const Node& node) const {
    return device_names_[node.assigned_device_name_index()];
  }
  void set_assigned_device_name_index(Node* node, int device_name_index) {
    CheckDeviceNameIndex(device_name_index);
    node->assigned_device_name_index_ = device_name_index;
  }
  void set_assigned_device_name(Node* node, const std::string& device_name) {
    node->assigned_device_name_index_ = InternDeviceName(device_name);
  }
  // Returns OK if `node` is non-null and belongs to this graph
  Status IsValidNode(const Node* node) const;
  // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
  // accept control outputs.
  Status IsValidOutputTensor(const Node* node, int idx) const;
  // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept
  // control inputs.
  Status IsValidInputTensor(const Node* node, int idx) const;
  // Create and return a new WhileContext owned by this graph. This is called
  // when a new while loop is created. `frame_name` must be unique among
  // WhileContexts in this graph.
  Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
                         std::vector<Node*> exit_nodes,
                         OutputTensor cond_output,
                         std::vector<OutputTensor> body_inputs,
                         std::vector<OutputTensor> body_outputs,
                         WhileContext** result);
  // Builds a node name to node pointer index for all nodes in the graph.
  std::unordered_map<string, Node*> BuildNodeNameIndex() const;
  absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
    return const_arg_indices_cache_;
  }
  // TODO(kkb): Add to the constructor when it becomes managable.
  // Sets the graph construction context.
  void SetConstructionContext(ConstructionContext construction_context) {
    construction_context_ = construction_context;
  }
  // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
  // making this stable and make it available widely.
  // Returns the graph construction context. It's `kUnknown` if not set.
  ConstructionContext GetConstructionContextInternal() const {
    return construction_context_;
  }
  // TODO(josh11b): uint64 hash() const;

 private:
  // If cost_node is non-null, then cost accounting (in CostModel)
  // will be associated with that node rather than the new one being
  // created.
  //
  // Ownership of the returned Node is not transferred to caller.
  Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                     const Node* cost_node, Node::NodeClass node_class);
  void ReleaseNode(Node* node);
  // Insert edge in free_edges_ for possible reuse.
  void RecycleEdge(const Edge* edge);
  // Registry of all known ops, including functions.
  FunctionLibraryDefinition ops_;
  // GraphDef versions
  const std::unique_ptr<VersionDef> versions_;
  // Allocator which will give us good locality.
  core::Arena arena_;
  // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
  // the node with that id was removed from the graph.
  std::vector<Node*> nodes_;
  // Number of nodes alive.
  int64_t num_nodes_ = 0;
  // Map from edge ids to allocated edges. edges_[id] may be nullptr if
  // the edge with that id was removed from the graph.
  std::vector<Edge*> edges_;
  // The number of entries in edges_ that are not nullptr.
  int num_edges_ = 0;
  // Allocated but free nodes and edges.
  std::vector<Node*> free_nodes_;
  std::vector<Edge*> free_edges_;
  // For generating unique names.
  int name_counter_ = 0;
  // In most graphs, the number of unique values used for the
  // Node::assigned_device_name() property is quite small. If the graph is
  // large, then this duplication of values can consume a significant amount of
  // memory. Instead, we represent the same information using an interning
  // table, which consists of a vector of unique strings (device_names_), as
  // well a map (device_names_map_) from unique strings to indices within the
  // unique string table.
  //
  // The InternDeviceName() method handles adding a new entry into the table,
  // or locating the index of an existing entry.
  //
  // The fact that Node::assigned_device_name() is implemented using an
  // interning table is intentionally public. This allows algorithms that
  // frequently access this field to do so efficiently, especially for the case
  // where the assigned_device_name of one Node is copied directly from that
  // of another Node.
  // A table of the unique assigned device names. Indices do NOT correspond
  // to node IDs. Index 0 is always the empty string.
  std::vector<string> device_names_;
  // Maps unique device names to indices within device_names_[i].
  std::unordered_map<string, int> device_names_map_;
  // All the while contexts owned by this graph, keyed by frame name,
  // corresponding to all the while loops contained in this graph (including
  // nested loops). The stored contexts are usually accessed via
  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
  std::map<string, WhileContext> while_ctxs_;
  // Cache of the indices of the arguments which need to be constant for the XLA
  // compilation.
  mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;
  // Indicates the context that this Graph instance is constructed.
  ConstructionContext construction_context_ = ConstructionContext::kNotTracked;
  TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};
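The comments on num_nodes(), num_node_ids() and FindNodeId() describe an id-indexed vector that keeps a nullptr hole when a node is removed. The sketch below is a heavily simplified, hypothetical model of just that bookkeeping (`IdNode` and `IdGraph` are illustrative, not TensorFlow types), showing why num_nodes() can be smaller than num_node_ids().

```cpp
#include <cassert>
#include <vector>

struct IdNode {
  int id;
};

class IdGraph {
 public:
  IdGraph() = default;
  // Like TF_DISALLOW_COPY_AND_ASSIGN: copying would double-free the nodes.
  IdGraph(const IdGraph&) = delete;
  IdGraph& operator=(const IdGraph&) = delete;
  ~IdGraph() {
    for (IdNode* n : nodes_) delete n;  // deleting nullptr is a no-op
  }

  IdNode* AddNode() {
    // New ids are always the current size of nodes_.
    IdNode* n = new IdNode{static_cast<int>(nodes_.size())};
    nodes_.push_back(n);
    ++num_nodes_;
    return n;
  }

  void RemoveNode(IdNode* n) {
    nodes_[n->id] = nullptr;  // leave a hole; later ids do not shift
    --num_nodes_;
    delete n;
  }

  int num_nodes() const { return num_nodes_; }  // live nodes only
  int num_node_ids() const { return static_cast<int>(nodes_.size()); }
  IdNode* FindNodeId(int id) const { return nodes_[id]; }

 private:
  std::vector<IdNode*> nodes_;
  int num_nodes_ = 0;
};
```

Because ids never shift, an array sized by num_node_ids() can be indexed directly by node id, and a loop over raw ids has to skip the null entries.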
Of these, the core data members are:
const std::unique_ptr<VersionDef> versions_;
core::Arena arena_;
std::vector<Node*> nodes_;
FunctionLibraryDefinition ops_;
std::vector<Edge*> edges_;
The core methods are (AllocateNode is private and is called internally by AddNode):
Node* AddNode(NodeDef node_def, Status* status);
void RemoveNode(Node* node);
const Edge* AddEdge(Node* source, int x, Node* dest, int y);
const Edge* AddControlEdge(Node* source, Node* dest, bool allow_duplicates = false);
Node* AllocateNode(std::shared_ptr<NodeProperties> props, const Node* cost_node, Node::NodeClass node_class);
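AddEdge and AddControlEdge both produce Edge objects; a control edge is distinguished by using the sentinel slot Graph::kControlSlot (defined as -1 in graph.cc) instead of a real output/input index. The sketch below is a hypothetical model of that idea (`SimpleEdge` and `EdgeList` are illustrative names, and nodes are plain integer ids); it also mimics the default `allow_duplicates = false` behavior of returning nullptr for a repeated control edge.

```cpp
#include <cassert>
#include <deque>

constexpr int kControlSlot = -1;  // same role as Graph::kControlSlot

struct SimpleEdge {
  int id;
  int src, src_output;  // producing node id and its output index
  int dst, dst_input;   // consuming node id and its input index
  bool IsControlEdge() const { return src_output == kControlSlot; }
};

class EdgeList {
 public:
  // Connects output `x` of `src` to input `y` of `dst`.
  const SimpleEdge& AddEdge(int src, int x, int dst, int y) {
    edges_.push_back(SimpleEdge{static_cast<int>(edges_.size()), src, x, dst, y});
    return edges_.back();
  }

  // Returns nullptr if an identical control edge already exists.
  const SimpleEdge* AddControlEdge(int src, int dst) {
    for (const SimpleEdge& e : edges_) {
      if (e.IsControlEdge() && e.src == src && e.dst == dst) return nullptr;
    }
    return &AddEdge(src, kControlSlot, dst, kControlSlot);
  }

  int num_edges() const { return static_cast<int>(edges_.size()); }

 private:
  // std::deque keeps references to existing elements valid on push_back,
  // so pointers returned above stay usable after later insertions.
  std::deque<SimpleEdge> edges_;
};
```

Encoding "no data" as a slot value rather than a separate edge type lets one container and one iteration loop handle both kinds of edges.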
The Graph constructor is:
Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}
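The constructor takes an OpRegistryInterface object. In the earlier post "tensorflow的op注册源码精读" (kangshuangzhu's CSDN blog) we saw that OpRegistry is a subclass of OpRegistryInterface, so an OpRegistry is what is normally passed in here. That is exactly what happens at the very start, in the TF_Graph constructor: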
graph(tensorflow::OpRegistry::Global())
As also covered in the OpRegistry post, the Global() method returns a single, process-wide OpRegistry instance.
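The two ideas here are that Graph depends only on the abstract registry interface, and that the concrete registry exposes one shared instance through a static Global(). A minimal sketch of that pattern using the common function-local-static singleton idiom; all names (`MiniOpRegistryInterface`, `MiniOpRegistry`, `RegistryBackedGraph`, `Register`, `CanHold`) are hypothetical and do not reproduce TensorFlow's actual classes.

```cpp
#include <cassert>
#include <map>
#include <string>

struct MiniOpInfo {
  std::string name;
};

// Abstract interface: the graph only ever sees this.
class MiniOpRegistryInterface {
 public:
  virtual ~MiniOpRegistryInterface() = default;
  // Returns nullptr if `op` is not registered.
  virtual const MiniOpInfo* LookUp(const std::string& op) const = 0;
};

class MiniOpRegistry : public MiniOpRegistryInterface {
 public:
  static MiniOpRegistry* Global() {
    // Initialized exactly once (thread-safe since C++11) and intentionally
    // leaked, so it stays valid for the whole process lifetime.
    static MiniOpRegistry* global = new MiniOpRegistry;
    return global;
  }
  void Register(const std::string& op) { ops_[op] = MiniOpInfo{op}; }
  const MiniOpInfo* LookUp(const std::string& op) const override {
    auto it = ops_.find(op);
    return it == ops_.end() ? nullptr : &it->second;
  }

 private:
  MiniOpRegistry() = default;  // only Global() can create the instance
  std::map<std::string, MiniOpInfo> ops_;
};

// Holds a non-owning interface pointer; the registry must outlive it.
class RegistryBackedGraph {
 public:
  explicit RegistryBackedGraph(const MiniOpRegistryInterface* ops) : ops_(ops) {}
  bool CanHold(const std::string& op) const { return ops_->LookUp(op) != nullptr; }

 private:
  const MiniOpRegistryInterface* ops_;
};
```

Since the global instance is never destroyed, a graph holding a pointer to it automatically satisfies the lifetime requirement stated on Graph's constructor.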
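In its initializer list, the constructor initializes three members (ops_ wraps the given registry together with an empty FunctionDefLibrary, versions_ is then filled with the producer and min_consumer versions, and arena_ starts with an 8 kB block):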
ops_
versions_
arena_
ops_ is declared in Graph as:
FunctionLibraryDefinition ops_;
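FunctionLibraryDefinition and OpRegistry both derive from OpRegistryInterface and play similar roles. Graph also has a second constructor that takes a FunctionLibraryDefinition. It delegates to the first constructor using the definition's default registry, then copies the function library into ops_: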
Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}
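The version check above only ever raises min_consumer: if the library contains functions and the current minimum is below 12, it is bumped to 12, otherwise it is left alone. A tiny sketch of that rule in isolation; `MiniVersionDef` and `RequireFunctionSupport` are hypothetical stand-ins, not the VersionDef protobuf API.

```cpp
#include <cassert>

// Stand-in for the VersionDef proto fields used here.
struct MiniVersionDef {
  int producer;
  int min_consumer;
};

// Mirrors the condition in the constructor: bump to 12 only when functions
// are present and the current minimum is lower; never decrease it.
void RequireFunctionSupport(MiniVersionDef* versions, int num_functions) {
  if (num_functions > 0 && versions->min_consumer < 12) {
    versions->min_consumer = 12;
  }
}
```

Raising the floor tells older consumers, which presumably could not execute graph functions, that they should not try to load this graph.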
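Because every graph must have a start node and an end node, the constructor also uses AddNode to add two NoOp nodes named "_SOURCE" (id kSourceId = 0) and "_SINK" (id kSinkId = 1), and then connects them with a control edge via AddControlEdge.
The source of AddNode is: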
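What a freshly constructed graph contains can be modeled in a few lines: exactly two nodes with fixed ids and one control edge between them, so num_op_nodes() starts at 0. A hypothetical sketch (`SketchNode`, `SketchEdge`, `SketchGraph` are illustrative types; assert stands in for the CHECK_EQ calls in the real constructor):

```cpp
#include <cassert>
#include <string>
#include <vector>

constexpr int kSourceId = 0;
constexpr int kSinkId = 1;
constexpr int kControlSlot = -1;

struct SketchNode {
  int id;
  std::string name;
  std::string op;
};

struct SketchEdge {
  int src, src_output, dst, dst_input;
};

class SketchGraph {
 public:
  SketchGraph() {
    int source = AddNode("_SOURCE", "NoOp");
    int sink = AddNode("_SINK", "NoOp");
    // The real constructor enforces these ids with CHECK_EQ.
    assert(source == kSourceId && sink == kSinkId);
    // Source and sink have no data endpoints, only a control edge.
    edges_.push_back({source, kControlSlot, sink, kControlSlot});
  }

  // Returns the new node's id (the current number of nodes).
  int AddNode(const std::string& name, const std::string& op) {
    int id = static_cast<int>(nodes_.size());
    nodes_.push_back({id, name, op});
    return id;
  }

  int num_nodes() const { return static_cast<int>(nodes_.size()); }
  int num_op_nodes() const { return num_nodes() - 2; }  // excludes source/sink
  const SketchNode& node(int id) const { return nodes_[id]; }
  const std::vector<SketchEdge>& edges() const { return edges_; }

 private:
  std::vector<SketchNode> nodes_;
  std::vector<SketchEdge> edges_;
};
```

Because the two nodes are added first, in a fixed order, their ids are guaranteed to be 0 and 1, which is what lets source_node() and sink_node() be simple FindNodeId lookups.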
Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node::NodeClass node_class = op_reg_data->is_function_op
                                   ? Node::NC_FUNCTION_OP
                                   : Node::GetNodeClassForOp(node_def.op());

  if (node_def.has_experimental_type()) {
    VLOG(3) << "AddNode: node has type set, skipping type constructor "
            << node_def.name();
  } else {
    if (op_reg_data->type_ctor != nullptr) {
      VLOG(3) << "AddNode: found type constructor for " << node_def.name();
      Status s =
          full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def,
                                    *(node_def.mutable_experimental_type()));
      if (!s.ok()) {
        *status = errors::InvalidArgument("type error: ", s.ToString());
        VLOG(3) << "AddNode: type inference failed for " << node_def.name()
                << ": " << s;
        return nullptr;
      }
    } else {
      VLOG(3) << "AddNode: no type constructor for " << node_def.name();
    }
  }

  Node* node = AllocateNode(std::make_shared<NodeProperties>(
                                &op_reg_data->op_def, std::move(node_def),
                                inputs, outputs, op_reg_data->fwd_type_fn),
                            nullptr, node_class);
  return node;
}

Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
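Roughly, the process is:
1. Look up whether the op named in the NodeDef is registered. If it is, LookUp points op_reg_data (a const OpRegistrationData*) at its registration data; otherwise AddNode sets *status and returns nullptr. InOutTypesForNode then infers the node's input and output data types from the OpDef.
2. Determine the node's NodeClass, which later stages depend on: function ops get NC_FUNCTION_OP, and every other op gets the class returned by Node::GetNodeClassForOp. NodeClass is an enum, covered in the post "tensorflow 之 Node, NodeDef, NodeProperties".
3. Call AllocateNode to add the new node. Its main arguments are a NodeProperties and a node_class (NodeProperties is also covered in "tensorflow 之 Node, NodeDef, NodeProperties"). AllocateNode first obtains an empty Node: if free_nodes_ (nodes that were released earlier) is empty, it allocates sizeof(Node) bytes from arena_ and constructs a Node there with placement new; otherwise it reuses the last node in free_nodes_. It then sets the node's graph_ pointer, takes the id as the current size of nodes_, and sets cost_id to cost_node's cost_id if a cost_node was given, or to the id otherwise. Initialize sets up the node with the id, cost_id, properties and node class. Finally the node is pushed onto the graph's nodes_ and num_nodes_ is incremented by 1.
Note that AllocateNode only adds the node to the graph's nodes_; it does not add any edges. Edges are added separately, through AddEdge or AddControlEdge.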
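The lookup-then-allocate flow, including recycling through free_nodes_, can be sketched as below. This is a hypothetical simplification (`PoolNode`, `PoolGraph` and the error text are illustrative): a vector of unique_ptr stands in for core::Arena, and a set of op names stands in for the registry. One detail it keeps from the real AllocateNode is that the id is always nodes_.size(), so a recycled Node object still receives a brand-new id.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct PoolNode {
  int id = -1;
  std::string op;
};

class PoolGraph {
 public:
  explicit PoolGraph(std::set<std::string> registered)
      : registry_(std::move(registered)) {}

  // Step 1: look up the op; on failure set an error and return nullptr.
  PoolNode* AddNode(const std::string& op, std::string* error) {
    if (registry_.count(op) == 0) {
      *error = "op not registered: " + op;  // illustrative message
      return nullptr;
    }
    return AllocateNode(op);
  }

  void RemoveNode(PoolNode* node) {
    nodes_[node->id] = nullptr;   // leave a hole at the old id
    free_nodes_.push_back(node);  // keep the memory for reuse
  }

  int num_node_ids() const { return static_cast<int>(nodes_.size()); }

 private:
  PoolNode* AllocateNode(const std::string& op) {
    PoolNode* node = nullptr;
    if (free_nodes_.empty()) {
      storage_.push_back(std::make_unique<PoolNode>());  // fresh memory
      node = storage_.back().get();
    } else {
      node = free_nodes_.back();  // recycle a released node
      free_nodes_.pop_back();
    }
    node->id = static_cast<int>(nodes_.size());  // always a new id
    node->op = op;
    nodes_.push_back(node);
    return node;
  }

  std::set<std::string> registry_;
  std::vector<std::unique_ptr<PoolNode>> storage_;  // owns every node
  std::vector<PoolNode*> nodes_;
  std::vector<PoolNode*> free_nodes_;
};
```

Reusing the memory of released nodes avoids growing the arena, while handing out fresh ids keeps every id in nodes_ unambiguous.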