1. Registration:
All JIT-related optimization passes are registered in one place, tensorflow/compiler/jit/jit_compilation_pass_registration.cc. For example:
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass);
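Here POST_REWRITE_FOR_EXEC is the stage (grouping). 10 is the phase, which fixes where the pass runs within that stage: passes in the same grouping run in ascending phase order. MarkForCompilationPass is the graph optimization pass being registered.
The REGISTER_OPTIMIZATION macro and the helper class it expands to: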
namespace tensorflow {
namespace optimization_registration {

// Helper whose constructor registers a pass with the global registry.
class OptimizationPassRegistration {
 public:
  OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,
                               int phase,
                               std::unique_ptr<GraphOptimizationPass> pass,
                               string optimization_pass_name) {
    pass->set_name(optimization_pass_name);
    OptimizationPassRegistry::Global()->Register(grouping, phase,
                                                 std::move(pass));
  }
};

}  // namespace optimization_registration

// __COUNTER__ gives every registration object a unique variable name.
#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)

#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)

#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)         \
  static ::tensorflow::optimization_registration::OptimizationPassRegistration \
      register_optimization_##ctr(                                             \
          grouping, phase,                                                     \
          ::std::unique_ptr<::tensorflow::GraphOptimizationPass>(              \
              new optimization()),                                             \
          #optimization)

}  // namespace tensorflow
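In effect, every REGISTER_OPTIMIZATION call defines a static OptimizationPassRegistration object, and that object's constructor registers a newly created instance of the pass.
All registered passes end up in std::map<Grouping, GraphOptimizationPasses> groups_;
Here Grouping identifies the stage, and GraphOptimizationPasses holds the passes that run in that stage.
This is similar to the Op registration mechanism.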
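The registration idea above can be tried out in isolation. The following is a minimal sketch, not TensorFlow code: MiniPass, MiniRegistry, MiniRegistration, MINI_REGISTER, PassA and PassB are all hypothetical stand-ins invented for this example, and the nested map keyed by phase is an assumption about how passes within one grouping are ordered.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for GraphOptimizationPass (not a TensorFlow type).
class MiniPass {
 public:
  virtual ~MiniPass() = default;
  void set_name(const std::string& name) { name_ = name; }
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

// Hypothetical stand-in for OptimizationPassRegistry.
class MiniRegistry {
 public:
  enum Grouping {
    PRE_PLACEMENT,
    POST_PLACEMENT,
    POST_REWRITE_FOR_EXEC,
    POST_PARTITIONING,
  };
  // Assumption: passes of one grouping are keyed by phase, so iterating
  // the map visits them in ascending phase order.
  using Passes = std::map<int, std::vector<std::unique_ptr<MiniPass>>>;

  // Function-local static, so the registry exists before any static
  // registration object tries to use it.
  static MiniRegistry* Global() {
    static MiniRegistry* registry = new MiniRegistry;
    return registry;
  }

  void Register(Grouping grouping, int phase, std::unique_ptr<MiniPass> pass) {
    groups_[grouping][phase].push_back(std::move(pass));
  }

  // Names of the passes of one grouping, in the order they would run.
  std::vector<std::string> NamesInRunOrder(Grouping grouping) const {
    std::vector<std::string> names;
    auto it = groups_.find(grouping);
    if (it == groups_.end()) return names;
    for (const auto& phase : it->second) {
      for (const auto& pass : phase.second) names.push_back(pass->name());
    }
    return names;
  }

 private:
  std::map<Grouping, Passes> groups_;
};

// The constructor does the registration, so defining a static object of
// this type is enough to register a pass.
class MiniRegistration {
 public:
  MiniRegistration(MiniRegistry::Grouping grouping, int phase,
                   std::unique_ptr<MiniPass> pass, std::string name) {
    assert(pass != nullptr);
    pass->set_name(name);
    MiniRegistry::Global()->Register(grouping, phase, std::move(pass));
  }
};

// Same three-level macro shape as REGISTER_OPTIMIZATION: the extra level
// forces __COUNTER__ to expand before it is pasted into the variable name.
#define MINI_REGISTER(grouping, phase, pass_class) \
  MINI_REGISTER_HELPER(__COUNTER__, grouping, phase, pass_class)
#define MINI_REGISTER_HELPER(ctr, grouping, phase, pass_class) \
  MINI_REGISTER_UNIQ(ctr, grouping, phase, pass_class)
#define MINI_REGISTER_UNIQ(ctr, grouping, phase, pass_class) \
  static MiniRegistration mini_registration_##ctr(           \
      grouping, phase, std::unique_ptr<MiniPass>(new pass_class()), #pass_class)

class PassA : public MiniPass {};
class PassB : public MiniPass {};

// Registered out of order on purpose; the phase decides the run order.
MINI_REGISTER(MiniRegistry::POST_REWRITE_FOR_EXEC, 20, PassB);
MINI_REGISTER(MiniRegistry::POST_REWRITE_FOR_EXEC, 10, PassA);
```

Calling MiniRegistry::Global()->NamesInRunOrder(MiniRegistry::POST_REWRITE_FOR_EXEC) from main returns PassA followed by PassB, even though PassB was registered first. No explicit registration call is needed at startup because the static objects are constructed before main runs.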
2. Invocation:
core/common_runtime/direct_session.cc
1652: TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
core/distributed_runtime/graph_mgr.cc
197: TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
core/common_runtime/graph_execution_state.cc
609: TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
621: TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
829: TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
Different optimization passes run at different stages of graph processing.
There are four stages:
enum Grouping {
  PRE_PLACEMENT,          // after cost model assignment, before placement.
  POST_PLACEMENT,         // after placement.
  POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
  POST_PARTITIONING,      // after partitioning
};
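To make the stage and phase ordering concrete, here is a small sketch of what running one grouping might look like. It is not the TensorFlow implementation: SketchPass, PhaseMap, RunGroupingSketch and RunAllStagesSketch are hypothetical names, a pass is modelled as a callback that returns true on success, and stopping at the first failing pass is an assumption. In TensorFlow the four stages are triggered from the different call sites listed above, not from a single loop; the loop here only illustrates the relative order.

```cpp
#include <cassert>
#include <functional>
#include <initializer_list>
#include <map>
#include <string>
#include <vector>

enum Grouping {
  PRE_PLACEMENT,
  POST_PLACEMENT,
  POST_REWRITE_FOR_EXEC,
  POST_PARTITIONING,
};

// Hypothetical: a pass is a name plus a callback returning true on success.
struct SketchPass {
  std::string name;
  std::function<bool()> run;
};

// Passes of one grouping keyed by phase; std::map iterates in ascending key
// order, so a lower phase runs first.
using PhaseMap = std::map<int, std::vector<SketchPass>>;

// Runs every pass of one grouping in phase order and records the names of the
// passes that were started. Assumption: the first failure aborts the grouping.
bool RunGroupingSketch(const std::map<Grouping, PhaseMap>& groups,
                       Grouping grouping, std::vector<std::string>* log) {
  assert(log != nullptr);
  auto group = groups.find(grouping);
  if (group == groups.end()) return true;  // nothing registered for this stage
  for (const auto& phase : group->second) {
    for (const SketchPass& pass : phase.second) {
      log->push_back(pass.name);
      if (!pass.run()) return false;
    }
  }
  return true;
}

// Registers a few dummy passes out of order and runs the four stages in
// sequence, returning the order in which the passes actually ran.
std::vector<std::string> RunAllStagesSketch() {
  std::map<Grouping, PhaseMap> groups;
  auto ok = [] { return true; };
  groups[POST_PARTITIONING][0].push_back({"partition_pass", ok});
  groups[POST_REWRITE_FOR_EXEC][20].push_back({"late_pass", ok});
  groups[POST_REWRITE_FOR_EXEC][10].push_back({"early_pass", ok});
  groups[PRE_PLACEMENT][0].push_back({"pre_pass", ok});

  std::vector<std::string> log;
  for (Grouping g : {PRE_PLACEMENT, POST_PLACEMENT, POST_REWRITE_FOR_EXEC,
                     POST_PARTITIONING}) {
    if (!RunGroupingSketch(groups, g, &log)) break;
  }
  return log;
}
```

RunAllStagesSketch() returns pre_pass, early_pass, late_pass, partition_pass: the stage decides the coarse order, and the phase decides the order inside a stage. POST_PLACEMENT has no passes here, so it is skipped without error.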