simple_graph_execution_state

simple_graph_execution_state.h

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
struct SessionOptions;
class StepStats;
class Timeline;

namespace subgraph {
struct RewriteGraphMetadata;
}

// Options used to construct a SimpleGraphExecutionState.  The pointers are
// borrowed and must outlive the execution state.
struct SimpleGraphExecutionStateOptions {
  const DeviceSet* device_set = nullptr;
  const SessionOptions* session_options = nullptr;
  // A map from node name to device name, representing the unchangeable
  // placement of stateful nodes.
  std::unordered_map<string, string> stateful_placements;
};

// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
// BuildGraphOptions.
struct SimpleClientGraph {
  explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
                             DataTypeVector feed_types,
                             DataTypeVector fetch_types)
      : flib_def(std::move(flib)),
        graph(flib_def.get()),
        feed_types(std::move(feed_types)),
        fetch_types(std::move(fetch_types)) {}
  // Each client-graph gets its own function library since optimization passes
  // post rewrite for execution might want to introduce new functions.
  // NOTE: `flib_def` must be declared before `graph`, because `graph` is
  // constructed with a raw pointer into `flib_def` in the init list above.
  std::unique_ptr<FunctionLibraryDefinition> flib_def;
  Graph graph;
  // Data types of the feed and fetch endpoints, in request order.
  DataTypeVector feed_types;
  DataTypeVector fetch_types;
};

// SimpleGraphExecutionState is responsible for generating an
// executable SimpleClientGraph from the original GraphDef that specifies
// the complete graph and from BuildGraphOptions which specifies
// input/output nodes.
//
// An executable Graph differs from a GraphDef by being Placed,
// meaning that each Node is assigned to a single Device in the
// available set.
//
// When SimpleGraphExecutionState is first constructed it instantiates
// a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef.  Nodes without a
// static assignment are currently placed in a very naive way.  Since
// stateful Nodes cannot be moved after initial placement, it is
// important that stateful Nodes get sensible initial device
// assignments in the graph definition.
//
// Subsequently, SimpleGraphExecutionState generates a SimpleClientGraph on
// demand, which is a sub-graph of the latest placement of the full
// Graph.  MasterSession uses such a SimpleClientGraph to execute one or
// more similar client requests.
//
// SimpleGraphExecutionState is thread-safe.
class SimpleGraphExecutionState {
 public:
  virtual ~SimpleGraphExecutionState();

  // Creates a new `SimpleGraphExecutionState` for the given
  // `graph_def`, which represents the entire graph for a session.
  //
  // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def`
  // in an undefined state. If it is necessary to use `*graph_def`
  // after this call, make an explicit copy of the graph before
  // calling this method.
  static Status MakeForBaseGraph(
      GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
      std::unique_ptr<SimpleGraphExecutionState>* out_state);

  // Creates a new `SimpleGraphExecutionState` and `SimpleClientGraph`
  // for the subgraph of `original_graph_def` defined by `subgraph_options`.
  static Status MakeForPrunedGraph(
      const FunctionDefLibrary& func_def_lib,
      const SimpleGraphExecutionStateOptions& options,
      const GraphDef& original_graph_def,
      const BuildGraphOptions& subgraph_options,
      std::unique_ptr<SimpleGraphExecutionState>* out_state,
      std::unique_ptr<SimpleClientGraph>* out_client_graph);

  // Creates a new SimpleGraphExecutionState representing the
  // concatenation of this graph, and the graph defined by
  // "extension_def". The same name may not be used to define a node
  // in both this graph and "extension_def".
  //
  // If successful, returns OK and the caller takes ownership of "*out".
  // Otherwise returns an error and does not modify "*out".
  //
  // After calling `old_state->Extend()`, `old_state` may no longer be used.
  //
  // NOTE(mrry): This method respects the placement of stateful nodes in
  // *this, but currently does not transfer any other placement
  // or cost model information to the new graph.
  Status Extend(const GraphDef& extension_def,
                std::unique_ptr<SimpleGraphExecutionState>* out) const;

  // Builds a SimpleClientGraph (a sub-graph of the full graph as induced by
  // the Node set specified in "options").  If successful, returns OK
  // and the caller takes the ownership of "*out". Otherwise, returns an error.
  Status BuildGraph(const BuildGraphOptions& options,
                    std::unique_ptr<SimpleClientGraph>* out);

  // The graph returned by BuildGraph may contain only the pruned
  // graph, whereas some clients may want access to the full graph.
  // Returns the placed full graph, or nullptr if InitBaseGraph() has not
  // run yet.  The returned pointer is owned by this object.
  const Graph* full_graph() {
    return graph_;
  }

  // Returns the node with the given name, or null if it does not exist.
  const Node* get_node_by_name(const string& name) const {
    NodeNameToCostIdMap::const_iterator iter =
        node_name_to_cost_id_map_.find(name);
    if (iter != node_name_to_cost_id_map_.end()) {
      return graph_->FindNodeId(iter->second);
    } else {
      return nullptr;
    }
  }

  // Returns a reference to the current graph_def.  Use must
  // not extend beyond lifetime of SimpleGraphExecutionState object.
  const GraphDef& original_graph_def() { return original_graph_def_; }

  // Returns the map of stateful placements as a map of node name to
  // placement (device name) string.  Returns a copy, taken under `mu_`.
  std::unordered_map<string, string> GetStatefulPlacements() const {
    mutex_lock l(mu_);
    return stateful_placements_;
  }

 private:
  // Private: instances are created via MakeForBaseGraph()/MakeForPrunedGraph().
  // Swaps the contents out of `*graph_def` (see MakeForBaseGraph's N.B.).
  SimpleGraphExecutionState(GraphDef* graph_def,
                            const SimpleGraphExecutionStateOptions& options);

  // Builds, optionally prunes, and places the full graph from
  // `original_graph_def_`, storing the result in `graph_`.
  Status InitBaseGraph(const BuildGraphOptions& options);

  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
  // is true, such as "params" and "queue" nodes.  Once placed these
  // nodes can not be moved to a different device.  Maps node names to
  // device names.
  std::unordered_map<string, string> stateful_placements_;  // Immutable after ctor.

  // Records / re-applies the device assignments of stateful nodes in
  // `graph` to / from `stateful_placements_`.
  void SaveStatefulNodes(Graph* graph);
  void RestoreStatefulNodes(Graph* graph);

  GraphDef original_graph_def_;            // Immutable after ctor.
  const DeviceSet* device_set_;            // Not owned
  const SessionOptions* session_options_;  // Not owned

  mutable mutex mu_;
  CostModel costs_ GUARDED_BY(mu_);

  // Map from name to Node for the full graph in placed_.
  NodeNameToCostIdMap node_name_to_cost_id_map_;

  // 'flib_def_' is initialized from the initial graph def's library,
  // and may be updated by a graph optimization pass.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;

  // `rewrite_metadata_` is only set for SimpleGraphExecutionState
  // objects created by `MakeForPrunedGraph()`.
  std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;

  // The dataflow graph owned by this object (raw owning pointer,
  // deleted in the destructor).
  Graph* graph_;

  TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_





simple_graph_execution_state.cc


#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"


#include <memory>
#include <string>
#include <unordered_set>
#include <vector>


#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"


#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // IS_MOBILE_PLATFORM


namespace tensorflow {


// Private constructor: copies borrowed pointers/placements from `options`,
// builds the function library from `graph_def->library()`, and takes the
// contents of `*graph_def` via Swap (leaving it in an undefined state).
SimpleGraphExecutionState::SimpleGraphExecutionState(
    GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options)
    : stateful_placements_(options.stateful_placements),
      device_set_(options.device_set),
      session_options_(options.session_options),
      costs_(true /*is_global*/),
      flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(),
                                              graph_def->library())),
      graph_(nullptr) {
  // NOTE(mrry): GraphDef does not have a move constructor, so we pass
  // a non-const pointer and use `Swap()` to transfer the contents
  // without copying.
  original_graph_def_.Swap(graph_def);
  // TODO(mrry): Publish placement visualizations or handle the log
  // placement option.
}


SimpleGraphExecutionState::~SimpleGraphExecutionState() {
  // `node_name_to_cost_id_map_` is destroyed automatically by its own
  // destructor; the previous explicit clear() was redundant.  Only the
  // raw-owned graph needs manual cleanup.
  // TODO: hold `graph_` in a std::unique_ptr<Graph> so this destructor
  // can be defaulted (Rule of Zero).
  delete graph_;
}


// Factory for the common (non-pruned) case: takes ownership of
// `graph_def`'s contents and, unless the session defers placement to
// pruning time, builds and places the full graph immediately.
/* static */ Status SimpleGraphExecutionState::MakeForBaseGraph(
    GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
    std::unique_ptr<SimpleGraphExecutionState>* out_state) {
  // The constructor swaps the contents out of `*graph_def`.
  std::unique_ptr<SimpleGraphExecutionState> state(
      new SimpleGraphExecutionState(graph_def, options));

  // Fill in attrs that were omitted because they have default values.
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&state->original_graph_def_,
                                               *state->flib_def_, 0));

  // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
  // pass an empty BuildGraphOptions (that isn't going to be used when
  // place_pruned_graph is false).
  const bool place_pruned =
      state->session_options_->config.graph_options().place_pruned_graph();
  if (!place_pruned) {
    TF_RETURN_IF_ERROR(state->InitBaseGraph(BuildGraphOptions()));
  }

  *out_state = std::move(state);
  return Status::OK();
}


// Factory for the pruned-graph case: builds an execution state from a copy
// of `graph_def` and immediately produces the pruned client graph for
// `subgraph_options`.  Only valid when the session was configured with
// `place_pruned_graph == true` (DCHECKed below).
// NOTE(review): `func_def_lib` appears unused in this definition — confirm
// against callers before removing it from the signature.
/* static */ Status SimpleGraphExecutionState::MakeForPrunedGraph(
    const FunctionDefLibrary& func_def_lib,
    const SimpleGraphExecutionStateOptions& options, const GraphDef& graph_def,
    const BuildGraphOptions& subgraph_options,
    std::unique_ptr<SimpleGraphExecutionState>* out_state,
    std::unique_ptr<SimpleClientGraph>* out_client_graph) {
  DCHECK(options.session_options->config.graph_options().place_pruned_graph());
  // NOTE(mrry): This makes a copy of `graph_def`, which is
  // regrettable. We could make `GraphDef` objects sharable between
  // execution states to optimize pruned graph execution, but since
  // this case is primarily used for interactive sessions, we make the
  // bet that graph construction is not performance-critical. (Note
  // also that the previous version used `Extend()`, which is strictly
  // more expensive than copying a `GraphDef`.)
  GraphDef temp(graph_def);
  std::unique_ptr<SimpleGraphExecutionState> ret(
      new SimpleGraphExecutionState(&temp, options));
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&ret->original_graph_def_,
                                               *ret->flib_def_.get(), 0));
  TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options));
  TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
  *out_state = std::move(ret);
  return Status::OK();
}


// Produces a new execution state whose graph is the union of this state's
// graph and `extension_def`.  Fails if any node name appears in both, or
// if the producer versions are incompatible.
Status SimpleGraphExecutionState::Extend(
    const GraphDef& extension_def,
    std::unique_ptr<SimpleGraphExecutionState>* out) const {
  std::unordered_set<string> new_names;
  // 1. Build an index of the new node names.
  for (const NodeDef& node : extension_def.node()) {
    new_names.insert(node.name());
  }

  // 2. Add the non-duplicates from the old graph to the new graph.
  //    Return an error if the same node name appears in both the
  //    old graph and the extension.
  GraphDef gdef;
  for (const NodeDef& node : original_graph_def_.node()) {
    if (new_names.count(node.name()) == 0) {
      *gdef.add_node() = node;
    } else {
      return errors::InvalidArgument(tensorflow::strings::Printf(
          "GraphDef argument to Extend includes node '%s', which was created "
          "by a previous call to Create or Extend in this session.",
          node.name().c_str()));
    }
  }

  // 3. Merge the nodes of the extension and fill in default attrs for the
  //    newly appended nodes only (those at index >= old_node_size).
  int old_node_size = gdef.node_size();
  gdef.mutable_node()->MergeFrom(extension_def.node());
  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&gdef, *flib_def_.get(), old_node_size));
  // Merge versions
  if (gdef.has_versions()) {
    if (gdef.versions().producer() != extension_def.versions().producer()) {
      return errors::InvalidArgument(
          "Can't extend GraphDef at version ", gdef.versions().producer(),
          " with graph at version ", extension_def.versions().producer());
    }
    VersionDef* versions = gdef.mutable_versions();
    versions->set_min_consumer(std::max(
        versions->min_consumer(), extension_def.versions().min_consumer()));
    if (extension_def.versions().bad_consumers_size()) {
      // Add new bad_consumers that aren't already marked bad.
      //
      // Note: This implementation is quadratic time if there are many calls to
      // ExtendLocked with many bad consumers.  Since this is unlikely, and
      // fixing it would require data structures outside of this routine,
      // quadratic time it is.
      auto* bad_consumers = versions->mutable_bad_consumers();
      const std::unordered_set<int> existing(bad_consumers->begin(),
                                             bad_consumers->end());
      for (const int v : extension_def.versions().bad_consumers()) {
        if (existing.find(v) == existing.end()) {
          bad_consumers->Add(v);
        }
      }
    }

  } else {
    // The old graph carried no versions; adopt the extension's wholesale.
    gdef.mutable_versions()->CopyFrom(extension_def.versions());
  }

  // 4. Copy the function library from this execution state.
  // NOTE(mrry): To match the previous behavior, the first GraphDef
  // passed to a session will contain the function library that is
  // used for all subsequent execution states.
  *gdef.mutable_library() = flib_def_->ToProto();

  // 5. Validate that the final graphdef is valid.
  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_.get()));
  }

  // 6. Add the extension.
  SimpleGraphExecutionStateOptions combined_options;
  combined_options.device_set = device_set_;
  combined_options.session_options = session_options_;
  combined_options.stateful_placements = stateful_placements_;

  // NOTE(mrry): `gdef` is no longer valid after the constructor
  // executes, because its contents are swapped into the new state.
  std::unique_ptr<SimpleGraphExecutionState> new_execution_state(
      new SimpleGraphExecutionState(&gdef, combined_options));

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
      &new_execution_state->original_graph_def_, *flib_def_.get(), 0));
  if (!session_options_->config.graph_options().place_pruned_graph()) {
    // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
    // pass an empty BuildGraphOptions (that isn't going to be used
    // when place_pruned_graph is false).
    TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions()));
  }
  *out = std::move(new_execution_state);

  // TODO(mrry): This is likely to be used for non-throughput-sensitive
  // interactive workloads, but in future we may want to transfer other
  // parts of the placement and/or cost model.
  return Status::OK();
}


// Records the assigned device of every stateful node in `graph` into
// `stateful_placements_`, so later placements can pin them to the same
// device.
void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) {
  for (Node* node : graph->nodes()) {
    if (!node->op_def().is_stateful()) {
      continue;
    }
    VLOG(2) << "Saving " << node->DebugString();
    stateful_placements_[node->name()] = node->assigned_device_name();
  }
}


// Re-applies previously saved device assignments from
// `stateful_placements_` to the stateful nodes of `graph`.  Nodes with no
// saved placement are left untouched.
void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) {
  for (Node* node : graph->nodes()) {
    if (!node->op_def().is_stateful()) {
      continue;
    }
    const auto placement = stateful_placements_.find(node->name());
    if (placement == stateful_placements_.end()) {
      continue;  // Never placed before; nothing to restore.
    }
    node->set_assigned_device_name(placement->second);
    VLOG(2) << "Restored " << node->DebugString();
  }
}


// Constructs the full `Graph` from `original_graph_def_`, optionally
// running the Grappler meta-optimizer first and (when the session uses
// `place_pruned_graph`) pruning to the feeds/fetches in `options`, then
// runs placement and the PRE/POST_PLACEMENT optimization passes.  On
// success, the placed graph is stored in `graph_` (owned).
Status SimpleGraphExecutionState::InitBaseGraph(
    const BuildGraphOptions& options) {
  const GraphDef* graph_def = &original_graph_def_;

#ifndef IS_MOBILE_PLATFORM
  GraphDef optimized_graph;

  const RewriterConfig& rewrite_options =
      session_options_->config.graph_options().rewrite_options();

  if (grappler::MetaOptimizerEnabled(rewrite_options)) {
    // Adding this functionality in steps. The first step is to make sure
    // we don't break dependencies. The second step will be to turn the
    // functionality on by default.
    grappler::GrapplerItem item;
    item.id = "tf_graph";
    item.graph = original_graph_def_;

    // Fetches are the requested output endpoints plus any target nodes.
    item.fetch = options.fetch_endpoints;
    item.fetch.insert(item.fetch.end(), options.target_nodes.begin(),
                      options.target_nodes.end());

    Status s;
    if (!options.feed_endpoints.empty()) {
      // Build placeholder feed tensors from each feed node's declared
      // dtype/shape attrs; fail if either attr is missing.
      std::unordered_set<string> feeds(options.feed_endpoints.begin(),
                                       options.feed_endpoints.end());
      for (const NodeDef& node : original_graph_def_.node()) {
        if (feeds.find(node.name()) == feeds.end()) {
          continue;
        }
        if (node.attr().count("dtype") == 0 ||
            node.attr().count("shape") == 0) {
          s = errors::InvalidArgument("Missing node shape or type");
          break;
        }
        TensorShape shape(node.attr().at("shape").shape());
        DataType type = node.attr().at("dtype").type();
        Tensor fake_input(type, shape);
        item.feed.emplace_back(node.name(), fake_input);
      }
    }

    if (s.ok()) {
      // Describe the available devices to Grappler via a virtual cluster.
      std::unordered_map<string, DeviceProperties> device_map;
      for (const auto& device : device_set_->devices()) {
        device_map[device->name()] =
            grappler::GetDeviceInfo(device->parsed_name());
      }
      grappler::VirtualCluster cluster(device_map);
      s = grappler::RunMetaOptimizer(item, rewrite_options, &cluster,
                                     &optimized_graph);
    }
    // On any optimizer failure we silently fall back to the unoptimized
    // graph rather than failing graph construction.
    if (s.ok()) {
      graph_def = &optimized_graph;
    }
  }
#endif  // IS_MOBILE_PLATFORM

  std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));
  // Record the name -> cost-id mapping so nodes can later be located by
  // name (see get_node_by_name()).
  for (const Node* n : new_graph->nodes()) {
    VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
    node_name_to_cost_id_map_[n->name()] = n->cost_id();
  }
  if (session_options_ &&
      session_options_->config.graph_options().place_pruned_graph()) {
    // Rewrite (prune) the graph to the requested feeds/fetches before
    // placement, keeping the rewrite metadata for BuildGraph().
    rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
    TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
        new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
        options.target_nodes, device_set_->client_device()->attributes(),
        options.use_function_convention, rewrite_metadata_.get()));
  }
 
  // Restore previously saved stateful-node placements before placing, so
  // stateful nodes remain pinned to their original devices.
  // (NOTE(review): the original comment said "Save stateful placements
  // before placing", but this call *restores* them; saving happens after
  // placement, below.)
  RestoreStatefulNodes(new_graph.get());

  // Seed a local cost model from the global one under the lock.
  CostModel costs(true /*is_global*/);
  {
    mutex_lock l(mu_);
    costs_.InitFromGraph(*new_graph.get());
    costs.MergeFromGlobal(costs_);
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &new_graph;
  optimization_options.flib_def = flib_def_.get();
  optimization_options.device_set = device_set_;
  optimization_options.cost_model = &costs;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  SimplePlacer placer(new_graph.get(), device_set_, session_options_);
  // TODO(mrry): Consider making the SimplePlacer cancelable.
  TF_RETURN_IF_ERROR(placer.Run());

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  // Remember where stateful nodes ended up so future placements keep them
  // on the same devices.
  SaveStatefulNodes(new_graph.get());
  graph_ = new_graph.release();
  return Status::OK();
}


// Builds a SimpleClientGraph: copies the placed full graph, prunes the
// copy to the feeds/fetches/targets in `options` (unless the graph was
// already pruned at construction), runs POST_REWRITE_FOR_EXEC passes, and
// returns a dense copy in `*out`.
Status SimpleGraphExecutionState::BuildGraph(
    const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
  VLOG(1) << "BuildGraph";
  if (!graph_) {
    // It is only valid to call this method directly when the original graph
    // was created with the option `place_pruned_graph == false`.
    return errors::Internal(
        "Attempted to prune a graph that has not been fully initialized.");
  }
  // Work on a copy so the placed full graph is preserved for later calls.
  std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
  CopyGraph(*graph_, ng.get());

  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    // Extract the subset of the graph that needs to be run, adding
    // feed/fetch ops as needed.
    TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
        ng.get(), options.feed_endpoints, options.fetch_endpoints,
        options.target_nodes, device_set_->client_device()->attributes(),
        options.use_function_convention, &rewrite_metadata));
  } else {
    // This SimpleGraphExecutionState represents a graph that was pruned
    // when this was constructed, so we copy the metadata from a member
    // variable.
    CHECK(rewrite_metadata_);
    rewrite_metadata = *rewrite_metadata_;
  }

  CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size());
  CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size());

  // Make a fresh copy of the function library for the client graph.
  std::unique_ptr<FunctionLibraryDefinition> flib(
      new FunctionLibraryDefinition(*flib_def_));

  // TODO(andydavis): Clarify optimization pass requirements around CostModel.
  // NOTE(review): `costs_` is GUARDED_BY(mu_) but is read here without
  // holding `mu_` — confirm whether this is intentional or a missing lock.
  CostModel costs(true /*is_global*/);
  costs.MergeFromGlobal(costs_);
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &ng;
  optimization_options.flib_def = flib.get();
  optimization_options.device_set = device_set_;
  optimization_options.cost_model = &costs;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  // Copy the extracted graph in order to make its node ids dense,
  // since the local CostModel used to record its stats is sized by
  // the largest node id.
  std::unique_ptr<SimpleClientGraph> dense_copy(
      new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types,
                            rewrite_metadata.fetch_types));
  CopyGraph(*ng, &dense_copy->graph);

  // TODO(vrv): We should check invariants of the graph here.

  *out = std::move(dense_copy);
  return Status::OK();
}


}  // namespace tensorflow


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值