// Thread compatible but not thread safe. class Graph { public: explicit Graph(const OpRegistryInterface* registry); explicit Graph(const FunctionLibraryDefinition& flib_def); ~Graph(); static const int kControlSlot; const VersionDef& versions() const; void set_versions(const VersionDef& versions); Node* AddNode(const NodeDef& node_def, Status* status); Node* CopyNode(Node* node); void RemoveNode(Node* node); const Edge* AddEdge(Node* source, int x, Node* dest, int y); const Edge* AddControlEdge(Node* source, Node* dest) { return AddEdge(source, kControlSlot, dest, kControlSlot); } void RemoveEdge(const Edge* edge); Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib); int num_nodes() const { return num_nodes_; } int num_op_nodes() const { DCHECK_GE(num_nodes_, 2); return num_nodes_ - 2; }
int num_edges() const { return num_edges_; } void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const; void ToGraphDef(GraphDef* graph_def) const; string NewName(StringPiece prefix);
gtl::iterator_range<NodeIter> nodes() const; gtl::iterator_range<NodeIter> op_nodes() const; int num_node_ids() const { return nodes_.size(); } Node* FindNodeId(int id) const { return nodes_[id]; }
int num_edge_ids() const { return edges_.size(); } const Edge* FindEdgeId(int id) const { return edges_[id]; } GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); } enum { kSourceId = 0, kSinkId = 1 }; Node* source_node() const { return FindNodeId(kSourceId); } Node* sink_node() const { return FindNodeId(kSinkId); }
const OpRegistryInterface* op_registry() const { return &ops_; } const FunctionLibraryDefinition& flib_def() const { return ops_; }
void CheckDeviceNameIndex(int index) { DCHECK_GE(index, 0); DCHECK_LT(index, static_cast<int>(device_names_.size())); }
int InternDeviceName(const string& device_name);
const string& get_assigned_device_name(const Node& node) const { return device_names_[node.assigned_device_name_index()]; }
void set_assigned_device_name_index(Node* node, int device_name_index) { CheckDeviceNameIndex(device_name_index); node->assigned_device_name_index_ = device_name_index; }
void set_assigned_device_name(Node* node, const string& device_name) { node->assigned_device_name_index_ = InternDeviceName(device_name); }
Status IsValidNode(const Node* node) const; Status IsValidOutputTensor(const Node* node, int idx) const;
private: Node* AllocateNode(std::shared_ptr<NodeProperties> props, const Node* cost_node); void ReleaseNode(Node* node); FunctionLibraryDefinition ops_; const std::unique_ptr<VersionDef> versions_; core::Arena arena_; std::vector<Node*> nodes_;
int64 num_nodes_ = 0; std::vector<Edge*> edges_; int num_edges_ = 0; std::vector<Node*> free_nodes_; std::vector<Edge*> free_edges_;
int name_counter_ = 0; std::vector<string> device_names_; std::unordered_map<string, int> device_names_map_; TF_DISALLOW_COPY_AND_ASSIGN(Graph); };