// simple_graph_execution_state.h
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
struct SessionOptions;
class StepStats;
class Timeline;
namespace subgraph {
struct RewriteGraphMetadata;
}
// Options used to construct a SimpleGraphExecutionState.
struct SimpleGraphExecutionStateOptions {
  // The set of devices available for placement. Not owned.
  const DeviceSet* device_set = nullptr;
  // Session-level configuration (e.g. graph options). Not owned.
  const SessionOptions* session_options = nullptr;
  // A map from node name to device name, representing the unchangeable
  // placement of stateful nodes.
  std::unordered_map<string, string> stateful_placements;
};
// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
// BuildGraphOptions.
struct SimpleClientGraph {
  explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
                             DataTypeVector feed_types,
                             DataTypeVector fetch_types)
      : flib_def(std::move(flib)),
        // NOTE: `graph` borrows the library owned by `flib_def`; the member
        // declaration order below guarantees `flib_def` is initialized first.
        graph(flib_def.get()),
        feed_types(std::move(feed_types)),
        fetch_types(std::move(fetch_types)) {}
  // Each client-graph gets its own function library since optimization passes
  // post rewrite for execution might want to introduce new functions.
  std::unique_ptr<FunctionLibraryDefinition> flib_def;
  Graph graph;
  // Data types of the feed and fetch endpoints, in the order they were
  // requested when the client graph was built.
  DataTypeVector feed_types;
  DataTypeVector fetch_types;
};
// SimpleGraphExecutionState is responsible for generating an
// executable SimpleClientGraph from the original GraphDef that specifies
// the complete graph and from BuildGraphOptions which specifies
// input/output nodes.
//
// An executable Graph differs from a GraphDef by being Placed,
// meaning that each Node is assigned to a single Device in the
// available set.
//
// When SimpleGraphExecutionState is first constructed it instantiates
// a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef. Nodes without are
// currently placed in a very naive way. Since stateful Nodes cannot
// be moved after initial placement, it is important that stateful
// Nodes get sensible initial device assignments in the graph
// definition.
//
// Subsequently, SimpleGraphExecutionState generates a SimpleClientGraph on
// demand, which is a sub-graph of the latest placement of the full
// Graph. MasterSession uses such a SimpleClientGraph to execute one or
// more similar client requests.
//
// SimpleGraphExecutionState is thread-safe.
class SimpleGraphExecutionState {
 public:
  virtual ~SimpleGraphExecutionState();

  // Creates a new `SimpleGraphExecutionState` for the given
  // `graph_def`, which represents the entire graph for a session.
  //
  // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def`
  // in an undefined state. If it is necessary to use `*graph_def`
  // after this call, make an explicit copy of the graph before
  // calling this method.
  static Status MakeForBaseGraph(
      GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
      std::unique_ptr<SimpleGraphExecutionState>* out_state);

  // Creates a new `SimpleGraphExecutionState` and `SimpleClientGraph`
  // for the subgraph of `original_graph_def` defined by `subgraph_options`.
  static Status MakeForPrunedGraph(
      const FunctionDefLibrary& func_def_lib,
      const SimpleGraphExecutionStateOptions& options,
      const GraphDef& original_graph_def,
      const BuildGraphOptions& subgraph_options,
      std::unique_ptr<SimpleGraphExecutionState>* out_state,
      std::unique_ptr<SimpleClientGraph>* out_client_graph);

  // Creates a new SimpleGraphExecutionState representing the
  // concatenation of this graph, and the graph defined by
  // "extension_def". The same name may not be used to define a node
  // in both this graph and "extension_def".
  //
  // If successful, returns OK and the caller takes ownership of "*out".
  // Otherwise returns an error and does not modify "*out".
  //
  // After calling `old_state->Extend()`, `old_state` may no longer be used.
  //
  // NOTE(mrry): This method respects the placement of stateful nodes in
  // *this, but currently does not transfer any other placement
  // or cost model information to the new graph.
  Status Extend(const GraphDef& extension_def,
                std::unique_ptr<SimpleGraphExecutionState>* out) const;

  // Builds a SimpleClientGraph (a sub-graph of the full graph as induced by
  // the Node set specified in "options"). If successful, returns OK
  // and the caller takes the ownership of "*out". Otherwise, returns an error.
  Status BuildGraph(const BuildGraphOptions& options,
                    std::unique_ptr<SimpleClientGraph>* out);

  // The graph returned by BuildGraph may contain only the pruned
  // graph, whereas some clients may want access to the full graph.
  const Graph* full_graph() {
    return graph_;
  }

  // Returns the node with the given name, or null if it does not exist.
  const Node* get_node_by_name(const string& name) const {
    NodeNameToCostIdMap::const_iterator iter =
        node_name_to_cost_id_map_.find(name);
    if (iter != node_name_to_cost_id_map_.end()) {
      return graph_->FindNodeId(iter->second);
    } else {
      return nullptr;
    }
  }

  // Returns a reference to the current graph_def. Use must
  // not extend beyond lifetime of SimpleGraphExecutionState object.
  const GraphDef& original_graph_def() { return original_graph_def_; }

  // Returns the map of stateful placements as a map of
  // node name to placement string.
  std::unordered_map<string, string> GetStatefulPlacements() const {
    mutex_lock l(mu_);
    return stateful_placements_;
  }

 private:
  SimpleGraphExecutionState(GraphDef* graph_def,
                            const SimpleGraphExecutionStateOptions& options);

  Status InitBaseGraph(const BuildGraphOptions& options);

  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
  // is true, such as "params" and "queue" nodes. Once placed these
  // nodes can not be moved to a different device. Maps node names to
  // device names.
  std::unordered_map<string, string> stateful_placements_;  // Immutable after ctor.
  void SaveStatefulNodes(Graph* graph);
  void RestoreStatefulNodes(Graph* graph);

  GraphDef original_graph_def_;  // Immutable after ctor.
  const DeviceSet* device_set_;  // Not owned
  const SessionOptions* session_options_;  // Not owned

  mutable mutex mu_;
  CostModel costs_ GUARDED_BY(mu_);

  // Map from name to Node for the full graph in placed_.
  NodeNameToCostIdMap node_name_to_cost_id_map_;

  // 'flib_def_' is initialized from the initial graph def's library,
  // and may be updated by a graph optimization pass.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;

  // `rewrite_metadata_` is only set for SimpleGraphExecutionState
  // objects created by `MakeForPrunedGraph()`.
  std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;

  // The dataflow graph owned by this object.
  Graph* graph_;

  TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
// simple_graph_execution_state.cc
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif // IS_MOBILE_PLATFORM
namespace tensorflow {
// Takes ownership of the contents of `*graph_def` (via Swap) and copies the
// non-owned pointers and stateful placements out of `options`.
SimpleGraphExecutionState::SimpleGraphExecutionState(
    GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options)
    : stateful_placements_(options.stateful_placements),
      device_set_(options.device_set),
      session_options_(options.session_options),
      costs_(true /*is_global*/),
      // The function library is built from the incoming graph's library,
      // backed by the global op registry.
      flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(),
                                              graph_def->library())),
      graph_(nullptr) {
  // NOTE(mrry): GraphDef does not have a move constructor, so we pass
  // a non-const pointer and use `Swap()` to transfer the contents
  // without copying. This leaves `*graph_def` in an undefined state.
  original_graph_def_.Swap(graph_def);
  // TODO(mrry): Publish placement visualizations or handle the log
  // placement option.
}
SimpleGraphExecutionState::~SimpleGraphExecutionState() {
  // NOTE: the explicit clear() is redundant (the map's own destructor would
  // release it), but harmless.
  node_name_to_cost_id_map_.clear();
  // `graph_` is a raw owning pointer; it may still be nullptr if
  // InitBaseGraph() was never run, and `delete nullptr` is a no-op.
  delete graph_;
}
/* static */ Status SimpleGraphExecutionState::MakeForBaseGraph(
    GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
    std::unique_ptr<SimpleGraphExecutionState>* out_state) {
  // The constructor swaps the contents out of `*graph_def`.
  std::unique_ptr<SimpleGraphExecutionState> ret(
      new SimpleGraphExecutionState(graph_def, options));

  // Fill in default-valued attrs that are missing from the stored graph def,
  // starting from node 0 (i.e. for all nodes).
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&ret->original_graph_def_,
                                               *ret->flib_def_.get(), 0));
  // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
  // pass an empty BuildGraphOptions (that isn't going to be used when
  // place_pruned_graph is false).
  if (!ret->session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions()));
  }
  *out_state = std::move(ret);
  return Status::OK();
}
/* static */ Status SimpleGraphExecutionState::MakeForPrunedGraph(
    const FunctionDefLibrary& func_def_lib,
    const SimpleGraphExecutionStateOptions& options, const GraphDef& graph_def,
    const BuildGraphOptions& subgraph_options,
    std::unique_ptr<SimpleGraphExecutionState>* out_state,
    std::unique_ptr<SimpleClientGraph>* out_client_graph) {
  // This factory is only meaningful when the session places pruned graphs.
  DCHECK(options.session_options->config.graph_options().place_pruned_graph());
  // NOTE(mrry): This makes a copy of `graph_def`, which is
  // regrettable. We could make `GraphDef` objects sharable between
  // execution states to optimize pruned graph execution, but since
  // this case is primarily used for interactive sessions, we make the
  // bet that graph construction is not performance-critical. (Note
  // also that the previous version used `Extend()`, which is strictly
  // more expensive than copying a `GraphDef`.)
  GraphDef temp(graph_def);
  std::unique_ptr<SimpleGraphExecutionState> ret(
      new SimpleGraphExecutionState(&temp, options));
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&ret->original_graph_def_,
                                               *ret->flib_def_.get(), 0));
  // Unlike MakeForBaseGraph(), both the base graph and the client graph are
  // built eagerly here, using the feed/fetch/target set in subgraph_options.
  TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options));
  TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
  *out_state = std::move(ret);
  return Status::OK();
}
// Produces a new execution state whose graph is the union of this state's
// graph and `extension_def`. Fails if any node name appears in both.
Status SimpleGraphExecutionState::Extend(
    const GraphDef& extension_def,
    std::unique_ptr<SimpleGraphExecutionState>* out) const {
  std::unordered_set<string> new_names;
  // 1. Build an index of the new node names.
  for (const NodeDef& node : extension_def.node()) {
    new_names.insert(node.name());
  }

  // 2. Add the non-duplicates from the old graph to the new graph.
  // Return an error if the same node name appears in both the
  // old graph and the extension.
  GraphDef gdef;
  for (const NodeDef& node : original_graph_def_.node()) {
    if (new_names.count(node.name()) == 0) {
      *gdef.add_node() = node;
    } else {
      return errors::InvalidArgument(tensorflow::strings::Printf(
          "GraphDef argument to Extend includes node '%s', which was created "
          "by a previous call to Create or Extend in this session.",
          node.name().c_str()));
    }
  }

  // 3. Merge the versions field.
  int old_node_size = gdef.node_size();
  gdef.mutable_node()->MergeFrom(extension_def.node());
  // Only the newly appended nodes (index >= old_node_size) need default
  // attrs filled in; the old nodes were already processed.
  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&gdef, *flib_def_.get(), old_node_size));
  // Merge versions
  if (gdef.has_versions()) {
    if (gdef.versions().producer() != extension_def.versions().producer()) {
      return errors::InvalidArgument(
          "Can't extend GraphDef at version ", gdef.versions().producer(),
          " with graph at version ", extension_def.versions().producer());
    }
    VersionDef* versions = gdef.mutable_versions();
    // min_consumer of the merged graph is the stricter of the two.
    versions->set_min_consumer(std::max(
        versions->min_consumer(), extension_def.versions().min_consumer()));
    if (extension_def.versions().bad_consumers_size()) {
      // Add new bad_consumers that aren't already marked bad.
      //
      // Note: This implementation is quadratic time if there are many calls to
      // ExtendLocked with many bad consumers. Since this is unlikely, and
      // fixing it would require data structures outside of this routine,
      // quadratic time it is.
      auto* bad_consumers = versions->mutable_bad_consumers();
      const std::unordered_set<int> existing(bad_consumers->begin(),
                                             bad_consumers->end());
      for (const int v : extension_def.versions().bad_consumers()) {
        if (existing.find(v) == existing.end()) {
          bad_consumers->Add(v);
        }
      }
    }
  } else {
    // `gdef` had no versions of its own; adopt the extension's versions.
    gdef.mutable_versions()->CopyFrom(extension_def.versions());
  }

  // 4. Copy the function library from this execution state.
  // NOTE(mrry): To match the previous behavior, the first GraphDef
  // passed to a session will contain the function library that is
  // used for all subsequent execution states.
  *gdef.mutable_library() = flib_def_->ToProto();

  // 5. Validate that the final graphdef is valid.
  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_.get()));
  }

  // 6. Add the extension.
  SimpleGraphExecutionStateOptions combined_options;
  combined_options.device_set = device_set_;
  combined_options.session_options = session_options_;
  combined_options.stateful_placements = stateful_placements_;

  // NOTE(mrry): `gdef` is no longer valid after the constructor executes
  // (the constructor swaps its contents into the new state).
  std::unique_ptr<SimpleGraphExecutionState> new_execution_state(
      new SimpleGraphExecutionState(&gdef, combined_options));
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
      &new_execution_state->original_graph_def_, *flib_def_.get(), 0));
  if (!session_options_->config.graph_options().place_pruned_graph()) {
    // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
    // pass an empty BuildGraphOptions (that isn't going to be used
    // when place_pruned_graph is false).
    TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions()));
  }
  *out = std::move(new_execution_state);

  // TODO(mrry): This is likely to be used for non-throughput-sensitive
  // interactive workloads, but in future we may want to transfer other
  // parts of the placement and/or cost model.
  return Status::OK();
}
// Records the assigned device of every stateful node in `graph` into
// stateful_placements_, so that subsequent placements can be pinned to the
// same devices.
void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) {
  for (Node* node : graph->nodes()) {
    if (!node->op_def().is_stateful()) {
      continue;
    }
    VLOG(2) << "Saving " << node->DebugString();
    stateful_placements_[node->name()] = node->assigned_device_name();
  }
}
// Re-applies the device assignments recorded in stateful_placements_ to the
// stateful nodes of `graph`. Stateful nodes without a recorded placement are
// left untouched.
void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) {
  for (Node* node : graph->nodes()) {
    if (!node->op_def().is_stateful()) {
      continue;
    }
    const auto placement = stateful_placements_.find(node->name());
    if (placement == stateful_placements_.end()) {
      continue;
    }
    node->set_assigned_device_name(placement->second);
    VLOG(2) << "Restored " << node->DebugString();
  }
}
// Converts original_graph_def_ (optionally grappler-optimized) into a placed
// Graph owned by this object, running PRE_PLACEMENT and POST_PLACEMENT
// optimization passes around the SimplePlacer.
Status SimpleGraphExecutionState::InitBaseGraph(
    const BuildGraphOptions& options) {
  const GraphDef* graph_def = &original_graph_def_;
#ifndef IS_MOBILE_PLATFORM
  GraphDef optimized_graph;

  const RewriterConfig& rewrite_options =
      session_options_->config.graph_options().rewrite_options();
  if (grappler::MetaOptimizerEnabled(rewrite_options)) {
    // Adding this functionalty in steps. The first step is to make sure
    // we don't break dependencies. The second step will be to turn the
    // functionality on by default.
    grappler::GrapplerItem item;
    item.id = "tf_graph";
    item.graph = original_graph_def_;

    // Fetches are the requested fetch endpoints plus the target nodes.
    item.fetch = options.fetch_endpoints;
    item.fetch.insert(item.fetch.end(), options.target_nodes.begin(),
                      options.target_nodes.end());

    Status s;
    if (!options.feed_endpoints.empty()) {
      std::unordered_set<string> feeds(options.feed_endpoints.begin(),
                                       options.feed_endpoints.end());
      for (const NodeDef& node : original_graph_def_.node()) {
        if (feeds.find(node.name()) == feeds.end()) {
          continue;
        }
        // Every fed node must declare a dtype and shape; otherwise we skip
        // the grappler optimization entirely (see the fall-through below).
        if (node.attr().count("dtype") == 0 ||
            node.attr().count("shape") == 0) {
          s = errors::InvalidArgument("Missing node shape or type");
          break;
        }
        // Build an uninitialized tensor of the declared type/shape so the
        // optimizer has a concrete feed to reason about.
        TensorShape shape(node.attr().at("shape").shape());
        DataType type = node.attr().at("dtype").type();
        Tensor fake_input(type, shape);
        item.feed.emplace_back(node.name(), fake_input);
      }
    }

    if (s.ok()) {
      // Describe the available devices to grappler via a virtual cluster.
      std::unordered_map<string, DeviceProperties> device_map;
      for (const auto& device : device_set_->devices()) {
        device_map[device->name()] =
            grappler::GetDeviceInfo(device->parsed_name());
      }
      grappler::VirtualCluster cluster(device_map);
      s = grappler::RunMetaOptimizer(item, rewrite_options, &cluster,
                                     &optimized_graph);
    }
    // On any failure above, we silently fall back to the unoptimized
    // original_graph_def_ rather than propagating the error.
    if (s.ok()) {
      graph_def = &optimized_graph;
    }
  }
#endif  // IS_MOBILE_PLATFORM

  std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));
  // Populate the name -> cost-id index used by get_node_by_name().
  for (const Node* n : new_graph->nodes()) {
    VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
    node_name_to_cost_id_map_[n->name()] = n->cost_id();
  }
  if (session_options_ &&
      session_options_->config.graph_options().place_pruned_graph()) {
    // Rewrite (prune) the graph before placement, and remember the rewrite
    // metadata so BuildGraph() can reuse it.
    rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
    TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
        new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
        options.target_nodes, device_set_->client_device()->attributes(),
        options.use_function_convention, rewrite_metadata_.get()));
  }

  // Restore stateful placements before placing: pin previously-placed
  // stateful nodes back onto their recorded devices.
  RestoreStatefulNodes(new_graph.get());

  CostModel costs(true /*is_global*/);
  {
    mutex_lock l(mu_);
    costs_.InitFromGraph(*new_graph.get());
    costs.MergeFromGlobal(costs_);
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &new_graph;
  optimization_options.flib_def = flib_def_.get();
  optimization_options.device_set = device_set_;
  optimization_options.cost_model = &costs;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  SimplePlacer placer(new_graph.get(), device_set_, session_options_);
  // TODO(mrry): Consider making the SimplePlacer cancelable.
  TF_RETURN_IF_ERROR(placer.Run());

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  // Record the final device assignments of stateful nodes.
  SaveStatefulNodes(new_graph.get());
  // NOTE(review): any previously-owned `graph_` is not deleted here; this
  // assumes InitBaseGraph() runs at most once per object — confirm.
  graph_ = new_graph.release();
  return Status::OK();
}
// Extracts the client sub-graph described by `options` from the placed full
// graph, runs the POST_REWRITE_FOR_EXEC passes, and returns a dense copy.
Status SimpleGraphExecutionState::BuildGraph(
    const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
  VLOG(1) << "BuildGraph";
  if (!graph_) {
    // It is only valid to call this method directly when the original graph
    // was created with the option `place_pruned_graph == false`.
    return errors::Internal(
        "Attempted to prune a graph that has not been fully initialized.");
  }
  // Work on a copy so the full graph is preserved for later calls.
  std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
  CopyGraph(*graph_, ng.get());

  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    // Extract the subset of the graph that needs to be run, adding feed/fetch
    // ops as needed.
    TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
        ng.get(), options.feed_endpoints, options.fetch_endpoints,
        options.target_nodes, device_set_->client_device()->attributes(),
        options.use_function_convention, &rewrite_metadata));
  } else {
    // This SimpleGraphExecutionState represents a graph that was pruned when
    // this was constructed, so we copy the metadata from a member variable.
    CHECK(rewrite_metadata_);
    rewrite_metadata = *rewrite_metadata_;
  }
  CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size());
  CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size());

  // Make a fresh copy of the function library for the client graph.
  std::unique_ptr<FunctionLibraryDefinition> flib(
      new FunctionLibraryDefinition(*flib_def_));

  // TODO(andydavis): Clarify optimization pass requirements around CostModel.
  CostModel costs(true /*is_global*/);
  // NOTE(review): `costs_` is annotated GUARDED_BY(mu_) but is read here
  // without holding `mu_` — confirm callers provide external serialization.
  costs.MergeFromGlobal(costs_);
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &ng;
  optimization_options.flib_def = flib.get();
  optimization_options.device_set = device_set_;
  optimization_options.cost_model = &costs;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  // Copy the extracted graph in order to make its node ids dense,
  // since the local CostModel used to record its stats is sized by
  // the largest node id.
  std::unique_ptr<SimpleClientGraph> dense_copy(
      new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types,
                            rewrite_metadata.fetch_types));
  CopyGraph(*ng, &dense_copy->graph);

  // TODO(vrv): We should check invariants of the graph here.
  *out = std::move(dense_copy);
  return Status::OK();
}
} // namespace tensorflow