TensorFlow: Graph

This section introduces the graph in TensorFlow. c_api.cc contains an example of creating a graph, which makes a good entry point for exploring how graphs are used.

TF_Graph

In c_api.cc, the code that creates a graph is:

TF_Graph* TF_NewGraph() { return new TF_Graph; }

TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {
  // Tell the shape refiner to also run shape inference on functions.
  refiner.set_function_library_for_shape_inference(&graph.flib_def());
}

TF_NewGraph creates a graph simply by constructing a TF_Graph with new. TF_Graph is defined in tensorflow/c/c_api_internal.h:

struct TF_Graph {
  TF_Graph();

  mutable tensorflow::mutex mu;
  tensorflow::Graph graph TF_GUARDED_BY(mu);

  // Runs shape inference.
  tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);

  // Maps from name of an operation to the Node* in 'graph'.
  std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
      TF_GUARDED_BY(mu);

  // The keys of this map are all the active sessions using this graph. Each
  // value records whether the graph has been mutated since the corresponding
  // session has been run (this is detected in RecordMutation function). If the
  // string is empty, no mutation has occurred. Otherwise the string is a
  // description of the mutation suitable for returning to the user.
  //
  // Sessions are added to this map in TF_NewSession, and removed in
  // TF_DeleteSession.
  // TF_Graph may only / must be deleted when
  //   sessions.size() == 0 && delete_requested == true
  //
  // TODO(b/74949947): mutations currently trigger a warning instead of a bad
  // status, this should be reverted when possible.
  tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
      TF_GUARDED_BY(mu);
  bool delete_requested TF_GUARDED_BY(mu);  // set true by TF_DeleteGraph

  // Used to link graphs contained in TF_WhileParams to the parent graph that
  // will eventually contain the full while loop.
  TF_Graph* parent;
  TF_Output* parent_inputs;
};
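The comment on `sessions` states a lifetime rule: a TF_Graph may only be deleted once no session uses it and TF_DeleteGraph has been called. Below is a minimal sketch of that rule. It uses a hypothetical `MiniTFGraph` with `int` session ids in place of `TF_Session*` and `std::mutex` in place of `tensorflow::mutex`. It illustrates the condition only and is not the c_api.cc code.

```cpp
#include <cassert>
#include <mutex>
#include <set>

// Hypothetical stand-in for TF_Graph: only the fields that govern its lifetime.
struct MiniTFGraph {
  std::mutex mu;
  std::set<int> sessions;         // ids of sessions still using this graph
  bool delete_requested = false;  // set by RequestDelete (cf. TF_DeleteGraph)
};

// The rule from the comment: delete only when
// sessions.size() == 0 && delete_requested == true. Caller holds g.mu.
bool ShouldDelete(const MiniTFGraph& g) {
  return g.sessions.empty() && g.delete_requested;
}

// Analogue of TF_NewSession registering itself with the graph.
void AddSession(MiniTFGraph& g, int session_id) {
  std::lock_guard<std::mutex> lock(g.mu);
  g.sessions.insert(session_id);
}

// Analogue of TF_DeleteGraph: mark the graph; true means "free it now".
bool RequestDelete(MiniTFGraph& g) {
  std::lock_guard<std::mutex> lock(g.mu);
  g.delete_requested = true;
  return ShouldDelete(g);
}

// Analogue of TF_DeleteSession: true means the last user is gone and the
// graph was already marked for deletion.
bool RemoveSession(MiniTFGraph& g, int session_id) {
  std::lock_guard<std::mutex> lock(g.mu);
  g.sessions.erase(session_id);
  return ShouldDelete(g);
}
```

In this model, a graph with an open session survives the delete request. The last RemoveSession call is what reports that the graph can be freed.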

TF_Graph is a struct. Its core members are:

tensorflow::Graph graph

tensorflow::ShapeRefiner refiner

graph 

Let's start with how graph is initialized. tensorflow::Graph is a class declared in tensorflow/core/graph/graph.h and implemented in tensorflow/core/graph/graph.cc.

The main members and methods of Graph are:

class Graph {
 public:
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in the registry. `ops`s lifetime must be at
  // least that of the constructed graph's.
  explicit Graph(const OpRegistryInterface* ops);

  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `flib_def`. Unlike the constructor taking
  // an OpRegistryInterface, this constructor copies the function definitions in
  // `flib_def` so its lifetime may be shorter than that of the graph's. The
  // OpRegistryInterface backing `flib_def` must still have the lifetime of the
  // graph though.
  explicit Graph(const FunctionLibraryDefinition& flib_def);

  ~Graph();

  // Clone the current graph into a new one.
  std::unique_ptr<Graph> Clone();

  static const int kControlSlot;

  // The GraphDef version range of this graph (see graph.proto).
  const VersionDef& versions() const;
  void set_versions(const VersionDef& versions);

  // Adds a new node to this graph, and returns it. Infers the Op and
  // input/output types for the node. *this owns the returned instance.
  // Returns nullptr and sets *status on error.
  Node* AddNode(NodeDef node_def, Status* status);

  // Same as above, but using StatusOr. This method is always preferred.
  StatusOr<Node*> AddNode(NodeDef node_def);

  // Copies *node, which may belong to another graph, to a new node,
  // which is returned.  Does not copy any edges.  *this owns the
  // returned instance.
  Node* CopyNode(const Node* node);

  // Removes a node from this graph, including all edges from or to it.
  // *node should not be accessed after calling this function.
  // REQUIRES: node->IsOp()
  void RemoveNode(Node* node);

  void Copy(const Graph& src);

  // Removes all nodes from this graph, including all edges from or to them.
  // No Node* references to the Graph are valid post.
  void Clear();

  // Adds an edge that connects the xth output of `source` to the yth input of
  // `dest` and returns it. Does not update dest's NodeDef.
  const Edge* AddEdge(Node* source, int x, Node* dest, int y);

  // Adds a control edge (no data flows along this edge) that connects `source`
  // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
  // adds the control input.
  //
  // If such a control edge already exists and `allow_duplicates` is false, no
  // edge is added and the function returns nullptr. Otherwise the edge is
  // unconditionally created and returned. The NodeDef is not updated if
  // `allow_duplicates` is true.
  // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
  // graph_partition.cc. Figure out if we can do away with it.
  const Edge* AddControlEdge(Node* source, Node* dest,
                             bool allow_duplicates = false);

  // Removes edge from the graph. Does not update the destination node's
  // NodeDef.
  // REQUIRES: The edge must exist.
  void RemoveEdge(const Edge* edge);

  // Removes control edge `edge` from the graph. Note that this also updates
  // the corresponding NodeDef to reflect the change.
  // REQUIRES: The control edge must exist.
  void RemoveControlEdge(const Edge* e);

  // Updates the input to a node.  The existing edge to `dst` is removed and an
  // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
  // is also updated.
  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);

  // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
  // "While" op during gradient construction, see AddInputWhileHack in
  // python_api.h for more details.
  Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);

  // Adds the function and gradient definitions in `fdef_lib` to this graph's op
  // registry. Ignores duplicate functions, and returns a bad status if an
  // imported function differs from an existing function or op with the same
  // name.
  Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);

  // The number of live nodes in the graph.
  //
  // Because nodes can be removed from the graph, num_nodes() is often
  // smaller than num_node_ids(). If one needs to create an array of
  // nodes indexed by node ids, num_node_ids() should be used as the
  // array's size.
  int num_nodes() const { return num_nodes_; }

  // The number of live nodes in the graph, excluding the Source and Sink nodes.
  int num_op_nodes() const {
    DCHECK_GE(num_nodes_, 2);
    return num_nodes_ - 2;
  }

  // The number of live edges in the graph.
  //
  // Because edges can be removed from the graph, num_edges() is often
  // smaller than num_edge_ids(). If one needs to create an array of
  // edges indexed by edge ids, num_edge_ids() should be used as the
  // array's size.
  int num_edges() const { return num_edges_; }

  // Serialize the nodes starting at `from_node_id` to a GraphDef.
  void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;

  // Serialize to a GraphDef.
  void ToGraphDef(GraphDef* graph_def) const;

  // This version can be called from debugger to inspect the graph content.
  // Use the previous version outside debug context for efficiency reasons.
  //
  // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
  // not defined in some TensorFlow builds.
  GraphDef ToGraphDefDebug() const;

  // Generate new node name with the specified prefix that is unique
  // across this graph.
  std::string NewName(StringPiece prefix);

  // Access to the list of all nodes.  Example usage:
  //   for (Node* node : graph.nodes()) { ... }
  gtl::iterator_range<NodeIter> nodes() const;

  // Access to the list of all nodes, excluding the Source and Sink nodes.
  gtl::iterator_range<NodeIter> op_nodes() const;

  // Returns one more than the maximum id assigned to any node.
  int num_node_ids() const { return nodes_.size(); }

  // Returns the node associated with an id, or nullptr if no node
  // with that id (the node with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_node_ids().
  Node* FindNodeId(int id) const { return nodes_[id]; }

  // Returns one more than the maximum id assigned to any edge.
  int num_edge_ids() const { return edges_.size(); }

  // Returns the Edge associated with an id, or nullptr if no edge
  // with that id (the edge with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_edge_ids().
  const Edge* FindEdgeId(int id) const { return edges_[id]; }

  // Access to the set of all edges.  Example usage:
  //   for (const Edge* e : graph.edges()) { ... }
  GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }

  // The pre-defined nodes.
  enum { kSourceId = 0, kSinkId = 1 };
  Node* source_node() const { return FindNodeId(kSourceId); }
  Node* sink_node() const { return FindNodeId(kSinkId); }

  const OpRegistryInterface* op_registry() const { return &ops_; }
  const FunctionLibraryDefinition& flib_def() const { return ops_; }

  // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
  FunctionLibraryDefinition* mutable_flib_def() { return &ops_; }

  void CheckDeviceNameIndex(int index) {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, static_cast<int>(device_names_.size()));
  }

  int InternDeviceName(const std::string& device_name);

  const std::string& get_assigned_device_name(const Node& node) const {
    return device_names_[node.assigned_device_name_index()];
  }

  void set_assigned_device_name_index(Node* node, int device_name_index) {
    CheckDeviceNameIndex(device_name_index);
    node->assigned_device_name_index_ = device_name_index;
  }

  void set_assigned_device_name(Node* node, const std::string& device_name) {
    node->assigned_device_name_index_ = InternDeviceName(device_name);
  }

  // Returns OK if `node` is non-null and belongs to this graph
  Status IsValidNode(const Node* node) const;

  // Returns OK if IsValidNode(`node`) and `idx` is a valid output.  Does not
  // accept control outputs.
  Status IsValidOutputTensor(const Node* node, int idx) const;

  // Returns OK if IsValidNode(`node`) and `idx` a valid input.  Does not accept
  // control inputs.
  Status IsValidInputTensor(const Node* node, int idx) const;

  // Create and return a new WhileContext owned by this graph. This is called
  // when a new while loop is created. `frame_name` must be unique among
  // WhileContexts in this graph.
  Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
                         std::vector<Node*> exit_nodes,
                         OutputTensor cond_output,
                         std::vector<OutputTensor> body_inputs,
                         std::vector<OutputTensor> body_outputs,
                         WhileContext** result);

  // Builds a node name to node pointer index for all nodes in the graph.
  std::unordered_map<string, Node*> BuildNodeNameIndex() const;

  absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
    return const_arg_indices_cache_;
  }

  // TODO(kkb): Add to the constructor when it becomes managable.
  // Sets the graph construction context.
  void SetConstructionContext(ConstructionContext construction_context) {
    construction_context_ = construction_context;
  }

  // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
  // making this stable and make it available widely.
  // Returns the graph construction context. It's `kUnknown` if not set.
  ConstructionContext GetConstructionContextInternal() const {
    return construction_context_;
  }

  // TODO(josh11b): uint64 hash() const;

 private:
  // If cost_node is non-null, then cost accounting (in CostModel)
  // will be associated with that node rather than the new one being
  // created.
  //
  // Ownership of the returned Node is not transferred to caller.
  Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                     const Node* cost_node, Node::NodeClass node_class);
  void ReleaseNode(Node* node);
  // Insert edge in free_edges_ for possible reuse.
  void RecycleEdge(const Edge* edge);
  // Registry of all known ops, including functions.
  FunctionLibraryDefinition ops_;

  // GraphDef versions
  const std::unique_ptr<VersionDef> versions_;

  // Allocator which will give us good locality.
  core::Arena arena_;

  // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
  // the node with that id was removed from the graph.
  std::vector<Node*> nodes_;

  // Number of nodes alive.
  int64_t num_nodes_ = 0;

  // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
  // the edge with that id was removed from the graph.
  std::vector<Edge*> edges_;

  // The number of entries in edges_ that are not nullptr.
  int num_edges_ = 0;

  // Allocated but free nodes and edges.
  std::vector<Node*> free_nodes_;
  std::vector<Edge*> free_edges_;

  // For generating unique names.
  int name_counter_ = 0;

  // In most graphs, the number of unique values used for the
  // Node::assigned_device_name() property is quite small.  If the graph is
  // large, then this duplication of values can consume a significant amount of
  // memory.  Instead, we represent the same information using an interning
  // table, which consists of a vector of unique strings (device_names_), as
  // well a map (device_names_map_) from unique strings to indices within the
  // unique string table.
  //
  // The InternDeviceName() method handles adding a new entry into the table,
  // or locating the index of an existing entry.
  //
  // The fact that Node::assigned_device_name() is implemented using an
  // interning table is intentionally public.  This allows algorithms that
  // frequently access this field to do so efficiently, especially for the case
  // where the assigned_device_name of one Node is copied directly from that
  // of another Node.

  // A table of the unique assigned device names.  Indices do NOT correspond
  // to node IDs.  Index 0 is always the empty string.
  std::vector<string> device_names_;

  // Maps unique device names to indices within device_names_[i].
  std::unordered_map<string, int> device_names_map_;

  // All the while contexts owned by this graph, keyed by frame name,
  // corresponding to all the while loops contained in this graph (including
  // nested loops). The stored contexts are usually accessed via
  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
  std::map<string, WhileContext> while_ctxs_;

  // Cache of the indices of the arguments which need to be constant for the XLA
  // compilation.
  mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;

  // Indicates the context that this Graph instance is constructed.
  ConstructionContext construction_context_ = ConstructionContext::kNotTracked;

  TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};
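The comments on `device_names_` and `device_names_map_` describe an interning table. It stores each device name once and lets every Node keep only a small integer index. The sketch below models that description. `DeviceNameTable` and its methods are hypothetical names, and the code is modeled on the comments rather than copied from graph.cc.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical interning table modeled on Graph's device_names_ /
// device_names_map_ comments: index 0 is always the empty string.
class DeviceNameTable {
 public:
  DeviceNameTable() { names_.push_back(""); }

  // Returns the index of `name`, adding it to the table if it is new.
  int Intern(const std::string& name) {
    if (name.empty()) return 0;  // the common "no device assigned" case
    auto it = index_.find(name);
    if (it != index_.end()) return it->second;
    const int idx = static_cast<int>(names_.size());
    names_.push_back(name);
    index_.emplace(name, idx);
    return idx;
  }

  const std::string& Get(int idx) const { return names_[idx]; }
  int size() const { return static_cast<int>(names_.size()); }

 private:
  std::vector<std::string> names_;              // unique names; names_[0] == ""
  std::unordered_map<std::string, int> index_;  // name -> position in names_
};
```

Thousands of nodes placed on the same device then share one string, and copying a device assignment from one node to another is just copying an int.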

Its core data members are:

const std::unique_ptr<VersionDef> versions_;

core::Arena arena_;

std::vector<Node*> nodes_;

FunctionLibraryDefinition ops_;

std::vector<Edge*> edges_;
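nodes_ and edges_ are indexed by id, and a removed entry becomes nullptr instead of being erased. That is why the header distinguishes num_nodes() (live nodes) from num_node_ids() (the size of the id range). The sketch below shows this bookkeeping with a hypothetical `IdTable`; it is an illustration, not Graph's code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical id table mirroring Graph::nodes_ / num_nodes_: removed entries
// leave a nullptr hole, so the live count can be smaller than the id range.
struct Item {
  int id;
};

class IdTable {
 public:
  IdTable() = default;
  IdTable(const IdTable&) = delete;             // owns raw pointers
  IdTable& operator=(const IdTable&) = delete;
  ~IdTable() {
    for (Item* p : items_) delete p;  // deleting nullptr is a no-op
  }

  Item* Add() {
    Item* it = new Item{static_cast<int>(items_.size())};  // id = next slot
    items_.push_back(it);
    ++num_live_;
    return it;
  }

  void Remove(Item* it) {
    items_[it->id] = nullptr;  // leave a hole; other ids stay valid
    --num_live_;
    delete it;
  }

  int num_live() const { return num_live_; }                       // cf. num_nodes()
  int num_ids() const { return static_cast<int>(items_.size()); }  // cf. num_node_ids()
  Item* Find(int id) const { return items_[id]; }                  // cf. FindNodeId()

 private:
  std::vector<Item*> items_;
  int num_live_ = 0;
};
```

An array indexed by id must therefore be sized with num_ids(), exactly as the num_nodes() comment advises.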

Its core methods are (AllocateNode is a private helper):

Node* AddNode(NodeDef node_def, Status* status);

void RemoveNode(Node* node);

const Edge* AddEdge(Node* source, int x, Node* dest, int y);

const Edge* AddControlEdge(Node* source, Node* dest,
                           bool allow_duplicates = false);

Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                   const Node* cost_node, Node::NodeClass node_class);
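The header comment on AddControlEdge specifies the duplicate behavior. If a control edge from source to dest already exists and allow_duplicates is false, nothing is added and nullptr is returned. The simplified edge store below illustrates that contract. `EdgeStore` and `MiniEdge` are hypothetical. The sketch uses -1 as the control slot, matching TensorFlow's Graph::kControlSlot, and omits the NodeDef updates the real method performs.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical edge store illustrating the documented AddEdge /
// AddControlEdge contract. Node ids are plain ints here.
constexpr int kControlSlot = -1;

struct MiniEdge {
  int src, src_slot, dst, dst_slot;
  bool IsControl() const { return src_slot == kControlSlot; }
};

class EdgeStore {
 public:
  // Data edge from output x of src to input y of dst.
  const MiniEdge* AddEdge(int src, int x, int dst, int y) {
    edges_.push_back(std::make_unique<MiniEdge>(MiniEdge{src, x, dst, y}));
    return edges_.back().get();
  }

  // Returns nullptr and adds nothing if src->dst already has a control edge
  // and duplicates are not allowed.
  const MiniEdge* AddControlEdge(int src, int dst, bool allow_duplicates = false) {
    if (!allow_duplicates) {
      for (const auto& e : edges_) {
        if (e->IsControl() && e->src == src && e->dst == dst) return nullptr;
      }
    }
    return AddEdge(src, kControlSlot, dst, kControlSlot);
  }

  int num_edges() const { return static_cast<int>(edges_.size()); }

 private:
  std::vector<std::unique_ptr<MiniEdge>> edges_;
};
```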

The Graph constructor is:

Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}

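The constructor takes a pointer to an OpRegistryInterface. In the post "A close reading of TensorFlow's op registration source code" (kangshuangzhu's CSDN blog) we showed that OpRegistry is a subclass of OpRegistryInterface, so an OpRegistry is normally passed in here. That is exactly what the TF_Graph constructor at the very beginning does: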

graph(tensorflow::OpRegistry::Global())

As that post also explained, OpRegistry::Global() returns the single, process-wide OpRegistry instance.
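A process-wide instance like the one returned by OpRegistry::Global() is commonly implemented as a function-local static. The sketch below shows that pattern with a hypothetical `MiniOpRegistry` that maps op names to summary strings instead of OpRegistrationData. It illustrates the singleton idea and is not TensorFlow's OpRegistry.

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <string>

// Hypothetical registry illustrating the "one global instance" pattern.
class MiniOpRegistry {
 public:
  static MiniOpRegistry* Global() {
    // Initialized exactly once (thread-safe since C++11) and never destroyed,
    // so it stays valid even during static destruction.
    static MiniOpRegistry* global = new MiniOpRegistry;
    return global;
  }

  void Register(const std::string& op, const std::string& summary) {
    std::lock_guard<std::mutex> lock(mu_);
    ops_[op] = summary;
  }

  bool LookUp(const std::string& op) const {
    std::lock_guard<std::mutex> lock(mu_);
    return ops_.count(op) > 0;
  }

 private:
  MiniOpRegistry() = default;  // only Global() can create the instance
  mutable std::mutex mu_;
  std::map<std::string, std::string> ops_;
};
```

Every caller of Global() sees the same object. That is why a Graph built with it can look up any op registered anywhere in the process.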

In its initializer list, the constructor initializes three members:

ops_

versions_

arena_

ops_ is declared in Graph as:

FunctionLibraryDefinition ops_;

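FunctionLibraryDefinition, like OpRegistry, derives from OpRegistryInterface and plays a similar role. Graph has a second constructor that takes a FunctionLibraryDefinition. It delegates to the first constructor with flib_def.default_registry(), raises min_consumer to 12 if the library contains any functions, and then copies the function definitions into ops_: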

Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}
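The delegating constructor above can be modeled in a few lines. In this sketch `Registry` and `Library` are hypothetical stand-ins for OpRegistryInterface and FunctionLibraryDefinition. The base min_consumer of 0 is an assumption; the real value comes from TF_GRAPH_DEF_VERSION_MIN_CONSUMER.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for OpRegistryInterface and FunctionLibraryDefinition.
struct Registry {};
struct Library {
  const Registry* default_registry;
  std::vector<std::string> functions;
};

class MiniGraph {
 public:
  // Assumed base min_consumer of 0 (cf. TF_GRAPH_DEF_VERSION_MIN_CONSUMER).
  explicit MiniGraph(const Registry* ops) : ops_(ops), min_consumer_(0) {}

  // Delegates to the registry constructor, then bumps min_consumer to 12
  // when the library carries functions, as Graph(const FunctionLibraryDefinition&) does.
  explicit MiniGraph(const Library& lib) : MiniGraph(lib.default_registry) {
    if (!lib.functions.empty() && min_consumer_ < 12) min_consumer_ = 12;
    functions_ = lib.functions;  // the real code calls ops_.AddLibrary(flib_def)
  }

  const Registry* op_registry() const { return ops_; }
  int min_consumer() const { return min_consumer_; }
  int num_functions() const { return static_cast<int>(functions_.size()); }

 private:
  const Registry* ops_;
  int min_consumer_;
  std::vector<std::string> functions_;
};
```

Delegation keeps the source/sink setup in one place, and the second constructor only adds what is specific to function libraries.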

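Because every graph must have a start node and an end node, the constructor also calls AddNode to add two NoOp nodes named "_SOURCE" and "_SINK". It checks that they receive ids kSourceId (0) and kSinkId (1), and connects them with a control edge.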
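The constructor's invariant is that the first two nodes are always _SOURCE and _SINK, so their ids are fixed. That invariant is also why num_op_nodes() is simply num_nodes() - 2. The miniature below illustrates it. `TinyGraph` is hypothetical and stores only names, op types and control edges.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical miniature of the Graph constructor: node ids are vector
// indices, and the first two nodes are always _SOURCE (0) and _SINK (1).
class TinyGraph {
 public:
  enum { kSourceId = 0, kSinkId = 1 };

  TinyGraph() {
    const int source = AddNode("_SOURCE", "NoOp");
    const int sink = AddNode("_SINK", "NoOp");
    assert(source == kSourceId && sink == kSinkId);
    control_edges_.emplace_back(source, sink);  // cf. AddControlEdge(source, sink)
  }

  int AddNode(const std::string& name, const std::string& op) {
    names_.push_back(name);
    ops_.push_back(op);
    return static_cast<int>(names_.size()) - 1;  // new node's id
  }

  int num_nodes() const { return static_cast<int>(names_.size()); }
  int num_op_nodes() const { return num_nodes() - 2; }  // excludes source/sink
  const std::string& name(int id) const { return names_[id]; }
  int num_control_edges() const { return static_cast<int>(control_edges_.size()); }

 private:
  std::vector<std::string> names_;
  std::vector<std::string> ops_;
  std::vector<std::pair<int, int>> control_edges_;
};
```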

The source of AddNode, and of the AllocateNode helper it calls, is:

Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node::NodeClass node_class = op_reg_data->is_function_op
                                   ? Node::NC_FUNCTION_OP
                                   : Node::GetNodeClassForOp(node_def.op());

  if (node_def.has_experimental_type()) {
    VLOG(3) << "AddNode: node has type set, skipping type constructor "
            << node_def.name();
  } else {
    if (op_reg_data->type_ctor != nullptr) {
      VLOG(3) << "AddNode: found type constructor for " << node_def.name();
      Status s =
          full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def,
                                    *(node_def.mutable_experimental_type()));
      if (!s.ok()) {
        *status = errors::InvalidArgument("type error: ", s.ToString());
        VLOG(3) << "AddNode: type inference failed for " << node_def.name()
                << ": " << s;
        return nullptr;
      }
    } else {
      VLOG(3) << "AddNode: no type constructor for " << node_def.name();
    }
  }

  Node* node = AllocateNode(std::make_shared<NodeProperties>(
                                &op_reg_data->op_def, std::move(node_def),
                                inputs, outputs, op_reg_data->fwd_type_fn),
                            nullptr, node_class);
  return node;
}


Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

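Roughly, the process is:

1. Look up the op named by node_def.op() in ops_. If the op is registered, op_reg_data is set to point at its OpRegistrationData. Otherwise AddNode sets *status and returns nullptr.

2. Compute the node's input and output DataTypes with InOutTypesForNode. If node_def has no experimental_type and the op has a type constructor, full_type::SpecializeType also fills in the node's full type.

3. Determine the node class: NC_FUNCTION_OP for function ops, otherwise Node::GetNodeClassForOp(node_def.op()). The result is used in later stages. NodeClass is an enum that is covered in "TensorFlow: Node, NodeDef, NodeProperties".

4. Call AllocateNode to add the new node. Its main arguments are a NodeProperties, also covered in "TensorFlow: Node, NodeDef, NodeProperties", and the node_class. AllocateNode first obtains memory for the node in one of two ways: it either placement-news a Node into arena_, or it reuses an entry from free_nodes_, the list of previously released nodes. It then sets node->graph_ and computes id = nodes_.size() and cost_id, which is taken from cost_node if one is given. These are passed together with the properties and node class to Node::Initialize. Finally the node is appended to the graph's nodes_ and num_nodes_ is incremented.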
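The lookup, type and node-class steps above can be sketched as a heavily simplified model. `AddNodeModel`, `RegData`, `MiniNodeDef` and `MiniNode` are hypothetical. RegData stands in for OpRegistrationData, types are plain strings, and the error text is illustrative rather than TensorFlow's message. The model copies the types from the registration instead of running InOutTypesForNode, and it skips full-type specialization.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified model of the Graph::AddNode flow.
enum class NodeClass { kFunctionOp, kOther };

struct RegData {  // stand-in for OpRegistrationData
  std::vector<std::string> input_types;
  std::vector<std::string> output_types;
  bool is_function_op = false;
};

struct MiniNodeDef {
  std::string name;
  std::string op;
};

struct MiniNode {
  int id;
  MiniNodeDef def;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  NodeClass node_class;
};

class AddNodeModel {
 public:
  explicit AddNodeModel(std::map<std::string, RegData> registry)
      : registry_(std::move(registry)) {}

  // Returns nullptr and fills *error when the op is not registered.
  MiniNode* AddNode(MiniNodeDef def, std::string* error) {
    // Look up the op's registration data (cf. ops_.LookUp).
    auto it = registry_.find(def.op);
    if (it == registry_.end()) {
      *error = "op not registered: " + def.op;
      return nullptr;
    }
    const RegData& reg = it->second;
    // Input/output types; the real code derives them with InOutTypesForNode.
    std::vector<std::string> inputs = reg.input_types;
    std::vector<std::string> outputs = reg.output_types;
    // Node class: function ops get their own class.
    NodeClass cls = reg.is_function_op ? NodeClass::kFunctionOp : NodeClass::kOther;
    // Allocate and record the node; no edges are created here.
    const int id = static_cast<int>(nodes_.size());
    nodes_.push_back(std::make_unique<MiniNode>(
        MiniNode{id, std::move(def), std::move(inputs), std::move(outputs), cls}));
    return nodes_.back().get();
  }

  int num_nodes() const { return static_cast<int>(nodes_.size()); }

 private:
  std::map<std::string, RegData> registry_;
  std::vector<std::unique_ptr<MiniNode>> nodes_;
};
```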

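Note that AllocateNode, and therefore AddNode, only records the node in the graph's nodes_. It does not create any edges; edges are added separately through AddEdge or AddControlEdge.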
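AllocateNode handles memory in two ways: it places new nodes in arena storage, and it recycles released nodes from free_nodes_ while still handing out a fresh id from nodes_.size(). The sketch below models both behaviors. `NodePool` and `PoolNode` are hypothetical, a fixed byte buffer stands in for core::Arena, and the Release method is an assumed simplification of Graph::ReleaseNode. Edges are never touched, which mirrors the point above.

```cpp
#include <cassert>
#include <cstddef>
#include <new>
#include <vector>

// Hypothetical model of Graph::AllocateNode's memory handling.
struct PoolNode {
  int id = -1;
  int cost_id = -1;
};

class NodePool {
 public:
  NodePool() = default;
  NodePool(const NodePool&) = delete;  // nodes point into this object's arena
  NodePool& operator=(const NodePool&) = delete;

  PoolNode* Allocate() {
    PoolNode* node = nullptr;
    if (free_nodes_.empty()) {
      assert(used_ < kCapacity && "arena exhausted");
      void* slot = arena_ + used_ * sizeof(PoolNode);
      node = new (slot) PoolNode;  // placement new into arena storage
      ++used_;
    } else {
      node = free_nodes_.back();  // reuse the memory of a released node
      free_nodes_.pop_back();
      *node = PoolNode();  // reset; in TF, Node::Initialize sets the node up again
    }
    node->id = static_cast<int>(nodes_.size());  // always a fresh id
    node->cost_id = node->id;                    // no cost_node in this sketch
    nodes_.push_back(node);
    ++num_nodes_;
    return node;
  }

  // Assumed simplification of Graph::ReleaseNode: clear the id slot and
  // keep the memory for later reuse.
  void Release(PoolNode* node) {
    nodes_[node->id] = nullptr;
    free_nodes_.push_back(node);
    --num_nodes_;
  }

  int num_nodes() const { return num_nodes_; }
  int num_node_ids() const { return static_cast<int>(nodes_.size()); }

 private:
  static constexpr std::size_t kCapacity = 16;
  alignas(PoolNode) unsigned char arena_[kCapacity * sizeof(PoolNode)];
  std::size_t used_ = 0;
  std::vector<PoolNode*> nodes_;
  std::vector<PoolNode*> free_nodes_;
  int num_nodes_ = 0;
};
```

Reused memory improves locality, and fresh ids keep any stale id held elsewhere from silently pointing at a different node.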

FunctionLibraryDefinition
