pass基础设施

本文翻译自Pass Infrastructure — tvm 0.9.dev0 documentation

Relay和TVM IR都包含一系列优化pass,用于改善模型的性能指标,如平均推断、内存占用或特定设备的功耗。有一套标准的优化和机器学习特有的优化,包括常量折叠、死代码消除、算子布局更改、算子融合、缓冲区处理和循环转换等。通过使用遍历期间和/或遍历之前收集的分析结果,将每个pass都被构造为ir-to-ir转换。

然而,随着TVM的迅速发展,对这些pass进行更系统和更有效管理越来越迫切。另外,一个管理跨TVM栈不同层(如Relay和tir)的pass的通用框架,为开发人员快速原型化和将实现的pass插入到系统中铺平了道路。

本文档描述了这样一个基础设施的设计,它利用了产品编译器管理优化pass的方式,以及现代深度学习框架中层构建的风格。

例如,许多现有的生产编译器,如GCC和LLVM,都采用了pass管理器来有效地管理pass的执行。最初管理pass是简单的,因为pass的数量很少,但成熟的编译器将包含数百个单独的pass。外部用户通常希望能够正确地安排自定义pass,而不需要人工安排pass顺序。

同样,现代的深度学习框架,如Pytorch和MXNet Gluon,也有通过Sequential和Block实现pass-style层构造方案的趋势。有了这样的结构,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松地构建神经网络。

Relay pass基础架构的设计很大程度上受到LLVM中使用的分层pass管理器,和流行的深度学习框架中使用的块式容器的启发。pass基础设施的主要目标包括:

  1. 支持更好的纲领性优化编排。这允许用户灵活地定制和构建自己的优化管道;
  2. 提供一种用户友好的方式来调试优化pass;
  3. 减轻开发人员人工地逐个地解决pass之间的依赖关系的工作量;
  4. 简化开发人员的新pass的实现。例如,我们允许用户在Python中实现一个pass,并让pass基础设施操作它的执行。

设计 

我们专注于为用户扩展易用性,使用户能够在不损失向后兼容性的情况下快速添加新的pass。该设计包括后端和前端。前者实现了pass基础设施的主要逻辑。后者提供了简单的API供用户交互,即允许用户快速创建自己的优化管道。

C++后端 

如下代码所示,我们提供一个PassInfo对象来包含pass所需的基本信息。name是pass的名称,opt_level表示在哪个优化级别将执行pass,required表示执行某个pass所需的pass(有关详细信息,请参阅include/tvm/ir/transform.h)。例如,在一个pass的注册过程中(将在后面讨论),pass开发人员可以指定pass的名称,它将执行的优化级别,和/或所需的pass。opt_level可用于帮助pass基础设施识别在用户提供的优化级别下运行时是否需要执行某个pass。required字段可以被pass基础架构用来解析pass依赖关系。

class PassInfoNode : public Object {
  String name;
  int opt_level;
  Array<String> required;
};

PassContext

PassContext为优化pass携带有用的信息。例如,它包含错误报告系统,这样优化开发人员就可以对优化失败的原因进行诊断。PassContext还被设计用来取代旧的BuildConfig,后者用于帮助用户配置编译选项,包括优化级别和所需/禁用的pass等。例如,我们可能有一个配置,在opt_level=3执行所有的pass,同时使用PassContext提供的disabled_pass=xx禁用一些pass。那么现在,我们可以查找opt_level=3的所有pass,然后排除那些在禁用pass列表中的pass。PassContext还提供了一种方法来检测所有的pass。参见Pass Instrument。 

这个类是为用户能便捷地使用语法编写Python而设计的,以便在特定的配置下执行优化。此外,用户可以通过PassContext::Current()以线程安全的方式获得在某个程序范围内可用的上下文,因为线程本地存储PassContextThreadLocalStore用于保存创建的pass上下文对象。后面将提供一些示例,展示如何通过C++和Python API使用pass上下文创建编译管道。

class PassContextNode : public Object {
 public:
  int opt_level{2};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
  mutable Optional<DiagnosticContext> diag_ctx;
  Map<String, ObjectRef> config;
  Array<instrument::PassInstrument> instruments;
};

class PassContext : public NodeRef {
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  TVM_DLL void InstrumentEnterPassContext();
  TVM_DLL void InstrumentExitPassContext();
  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
  /* Other fields are omitted. */

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();

  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};

struct PassContextThreadLocalEntry {
  /*! \brief The default pass context. */
  PassContext default_context;
  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;
  PassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
     PassContextThreadLocalStore;

 Pass构造

pass基础架构是分层设计的,可以在不同粒度的Relay/tir程序中工作。引入一个纯虚类PassNode作为不同优化pass的基础。该类包含几个虚方法,这些方法必须由模块、函数或pass序列子类实现。

class PassNode : Object {
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const IRModule& mod
                            const PassContext& pass_ctx) const = 0;
};

 函数器展示了一个pass必须如何实现,也就是说,它总是在IRModule的特定上下文中工作。所有的pass都是以模块到模块的方式设计的。因此,由pass基础架构管理的优化将始终更新整个模块。

已经创建了几个子类来实现不同类型的优化pass,例如,函数级pass、模块级pass和序列pass。每个子类本身可以充当pass管理器。例如,他们可以收集所需的pass并执行它们,或者基于给定的元数据构建依赖关系图。它们的完整定义可以在src/relay/ir/transform.cc和src/ir/ transform.cc中找到。

模块级别pass

模块级pass主要用于全局和过程间优化(IPO),这与LLVM中使用的模块pass类似。Relay中一些需要模块全局图的典型pass,如a范式转换、lambda提升等,都属于这个集合。在这个级别上,用户甚至可以在模块中添加和/或删除函数。注意,所有的pass 

class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

 pass_info维护模块级pass所需的信息。pass_func描述了真正的优化。例如,我们可能需要在模块上执行死代码消除。我们可以在pass_func中实现算法,并让它在一个模块上运行。然后,它将删除死代码,包括模块中未使用的函数。请注意,该字段被设计为一个打包函数,它支持在C++和Python中实现优化。

函数级别pass 

函数级pass用于为给定的Relay/tir模块实现各种内部函数级优化。它每次从模块的函数列表中获取一个函数用于优化,并生成一个重写的Relay Function或tir PrimFunc。大多数的pass都可以归为这一类,如Relay中常见的子表达式消除和推理简化,以及tir中的矢量化和扁平化存储等。

注意,这个级别的pass的作用域是一个Relay函数或一个tir原语函数。因此,我们不能通过这些pass添加或删除函数,因为它们不知道全局信息。

class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

 pass_info与我们刚才在模块pass中描述的一致。pass_func接受一个函数来进行优化,它还需要一个模块,因为我们可能会使用它来报告错误。函数可以注释为“SkipOptimization”,以便在优化期间忽略它。

序列pass

SequentialPass类似于nn. Sequentia,包含一系列顺序执行的pass。

class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

 目前在Relay中只有少数的pass被放在这个组中。例如,FoldScaleAxis需要内部分派ForwardFoldScaleAxis和BackwardFoldScaleAxis。此外,建议首先完成BackwardFoldScaleAxis。因此,该pass是SequentialPass的理想候选者。

下面的代码展示了如何调用顺序pass中的单个pass。从本质上讲,我们按添加顺序执行序列中每个pass。

Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
  Module mod = module;
  for (const Pass& pass : passes) {
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    if (!PassEnabled(pass_info))  continue;
    for (const auto& it : pass_info->required) {
      const auto* name = it.as<tvm::ir::StringImm>();
      ICHECK(name);
      mod = GetPass(name->value)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}

在调用一个pass时,我们首先检查这个pass是否启用。首先检查pass是否被用户显式禁用,然后检查它是否被用户指定为必需的pass。如果仍然不确定是否启用这个pass,那么将检查它的opt_level。只有当它的优化级别不低于在pass上下文中配置的优化级别时,该pass才会启用并执行。

要执行pass,我们首先需要使用pass名称在TVM打包的函数注册表中检索注册的pass。这是可行的,因为每一个pass都是用一个API端点注册的,我们将在后面展示。

Pass GetPass(const std::string& pass_name) {
  using tvm::runtime::Registry;
  std::string fpass_name = "relay._transform." + pass_name;
  const auto* f = Registry::Get(fpass_name);
  ICHECK(f != nullptr) << "Cannot find " << fpass_name
                      << "to create the pass " << pass_name;
  return (*f)();
}

 一些帮助函数被提供用来来创建上述每种类型的pass。这些帮助函数还被公开到Python前端,以便用户使用Python API创建特定的pass对象。

Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreateModulePass(
    const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

 pass注册

我们已经介绍了不同级别的pass的概念以及编译时使用的上下文。看看用户注册pass有多容易将是一件有趣的事情。让我们以常量折叠为例。这个pass已经实现了在一个Relay函数中折叠常量(可以在src/relay/transforms/fold_constant.cc中找到)。

一个API被提供用来执行Expr到Expr转换。

Expr FoldConstant(const Expr& expr);

为了将这个pass注册到pass基础设施,我们首先需要决定在哪个级别执行该pass。由于常量折叠发生在单个函数上,直观地我们应该通过CreateFunctionPass为它创建一个FunctionPass。pass_func作为一个打包函数返回,它在IRModule中的每个函数上调用Expr 到 Expr API。{}表示此pass不需要任何先决条件。否则,pass开发人员必须识别并列出依赖。

同时,使用relay._transform.FoldConstant名称注册一个pass API端点。这样,这个pass成为注册表中的一个条目,C++(例如上面的GetPass)和Python在需要时都可以访问它。

namespace transform {

Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

}  // namespace transform

为了允许其他c++模块也使用此pass,我们在include/tvm/relay/transform.h中添加函数声明,如下所示:

TVM_DLL Pass FoldConstant();

Pass Instrument

Pass Instrument是一种分析Pass本身的机制。例如,我们可以使用基础设施来知道一个pass需要多少时间和内存,或者一个pass如何转换IR模块。

我们介绍了PassContext生命周期中的四个测量点。

TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;

当进入PassContext实例的作用域时,InstrumentEnterPassContext立即被调用。

当离开PassContext的作用域,或者在pass的执行过程中发生异常时,将调用InstrumentExitPassContext。当在tvm.transform.PassContext中被override_instruments重载时,也会调用此方法。请参阅 Override Instruments in Current PassContext

InstrumentBeforePass在执行之前被调用。如果pass被执行,则InstrumentAfterPass在执行后被调用。行为如下:

if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
  new_ir_module = run_pass(ir_module, pass_ctx);
  pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
  return new_ir_module;
}

 PassInstrument接口允许您在以上四个方法中运行任意代码。可以将多个PassInstrument实例注册到单个PassContext中。PassInstrument实例按照传递给PassContext的instrument参数的顺序被调用。

PassInstrument提供以下接口:

namespace instrument {

class PassInstrumentNode : public Object {
 public:
  String name;
  virtual void EnterPassContext() const = 0;
  virtual void ExitPassContext() const = 0;
  virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  /* Other fields are omitted. */
};

class PassInstrument : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};

}  // namespace instrument

Python前端被提供用来快速实现PassInstrument。请参阅Pass Instrument.

在PassContext中,PassInstrument实例的调用序列如下:

with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
    pi.EnterPassContext()

    if pi.ShouldRun(Pass1):
        pi.RunBeforePass()
        Pass1()
        pi.RunAfterPass()

    if pi.ShouldRun(Pass2):
        pi.RunBeforePass()
        Pass2()
        pi.RunAfterPass()

    pi.ExitPassContext()

下面简要介绍PassInstrument接口和PassContext方法之间的关系。请参阅(src/ir/transform.cc)了解更多细节。

  • InstrumentEnterPassContext

EnterPassContext()是按照传递给PassContext的instrument的顺序执行的。

当发生异常时,PassContext通过清除所有注册的PassInstrument实例来禁用pass测量。

然后PassContext对每个成功完成EnterPassContext()的PassInstrument实例执行ExitPassContext()方法

例如,如果PassInstrument A、B和C被注册到PassContext, A完成了EnterPassContext(),而B抛出异常,那么C永远不会被执行;A的ExitPassContext()被执行。

  • InstrumentExitPassContext

每个PassInstrument实例的ExitPassContext()将按照传递给PassContext的instrument的顺序执行。

当异常发生时,instrument被清除。

在抛出异常之后注册的PassInstrument实例不执行ExitPassContext。

  • InstrumentBeforePass 

如果pass没有被列为必需的pass,则执行ShouldRun。

如果pass没有被ShouldRun阻塞,RunBeforePass将按照instrument的顺序执行。

注意,InstrumentBeforePass返回一个布尔值,指示是否应该运行该pass。

当异常发生时,异常会立即被抛出。我们依赖Python Context Manager安全退出PassContext(这意味着每个instrument的ExitPassContext将被运行。对于C++,请参考include/tvm/support/with.h。)

  •  InstrumentAfterPass

RunAfterPass按照传递给PassContext的instrument的顺序执行。

当异常发生时,异常会立即被抛出。我们依赖于Python Context Manager或With类(include/tvm/support/ with .h)安全地退出PassContext

内建Instrument

有几个内置的instrument。其中标记了TODO的还没有实现。

  • PassTimingInstrument(见src/ir/instrument.cc)

配置pass的执行时间。

  • PrintIRBefore (TODO)

在pass转换IR模块前打印该IR模块。如果在pass周围插入tvm.transform.PrintIR()也可以达到这个目的。但是,使用PassInstrument,我们不需要修改pass的序列。

  • PrintAfter (TODO)

在pass转换IR模块后打印该IR模块。

Python前端 

 前端只需要一些简单的API。例如,我们可以为用户提供以下API来创建和执行一个pass(完整的实现在python/tvm/relay/transform/transform.py和python/tvm/ir/transform.py中提供)。后端接收信息并决定使用哪个函数来创建Pass对象。

PassContext

Python前端通过重载__enter__和__exit__为PassContext提供了一个包装器来启用with语法。为用户提供了一个current静态方法来获取在一定范围内正在使用的上下文。 

@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace, config):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()

 PassContext用于配置编译选项,包括优化级别和所需/禁用的pass。它还可以使用一个配置字典,以便不同的pass可以方便地获取传递的数据,例如回退设备信息和循环展开的步骤/深度等。为了能够获取所需的配置,密钥必须通过TVM_REGISTER_PASS_CONFIG_OPTION进行注册。例如,循环展开pass使用以下代码

TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

更多细节请查阅src/tir/transforms/unroll_loop.cc 

pass对象

Pass是所有pass对象的基类。这里的所有方法都只是后端实现的简单包装器。它们是为用户定义的,以便与Python中的基类交互。在pass基类中只定义了__call__,以使子类成为可调用对象,以便它们可以轻松调用(例如pass_xx(arg))执行。

@register_relay_node
class Pass(RelayNode):
   def __call__(self, mod):
       return _transform.RunPass(self, mod)

 一些辅助API被提供以支持从Python前端轻松创建pass,并让pass基础设施控制执行。例如,module_pass、function_pass和sequential被提供给用户,以便用户可以定制自己的pass或pass管道。

对于所有在C++后端实现的pass,我们分别在Python /tvm/ir/transform.py和Python /tvm/relay/transform/transform.py中提供了相应的Python API。例如,常量折叠有一个如下的Python API:

def FoldConstant():
    return _transform.FoldConstant()

用户可以通过装饰器创建一个pass,如下所示:

 @relay.transform.module_pass(opt_level=2)
 def transform(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("abs")
    func = relay.Function([x], relay.abs(x))
    new_mod = tvm.IRModule({gv: func})
    new_mod.update(mod)
    return new_mod

module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2

这里的transform函数向输入模块添加了一个abs函数,但它也可以是模块级的任何定制优化。创建这个module_pass之后,用户可以将它应用到任何Relay模块上。例如,我们可以构建一个空模块,并应用此pass来添加一个abs函数。

mod = tvm.IRModule()
mod = module_pass(mod)

 相应地,我们也为function_pass提供了这样的功能。例如,一个函数级pass的例子可以这样写:

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
   def __init__(self, new_func):
      self.new_func = new_func
      def transform_function(self, func, mod, ctx):
         # Just for demo purposes
         # Transform func to new_func
         return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)

或者,用户也可以直接注册一个pass,而不使用装饰器,然后调用它。定制自己的优化管道,调试Relay和tir通道的更多示例,请参阅use pass infra 。

Pass Instrument

你可以通过在一个实现以下方法的类上,使用pass_instrument装饰器(python/tvm/ir/instrument.py)来实现PassInstrument。注意,建议使用pass_instrument装饰器来实现PassInstrument,而不是重载或子类化。

  • enter_pass_ctx

该方法在进入PassContext时运行。

  • exit_pass_ctx

该方法在退出PassContext时运行。

  • should_run

此方法在执行pass之前运行,返回一个布尔值,指示是否应该运行pass。

  • run_before_pass

如果要运行一个pass,这个方法会在pass执行之前运行。

  • run_after_pass

此方法在执行一个pass之后立即运行。

PassInstrument实例可以通过tvm.transform.PassContext中的instruments参数注册。

use pass instrument教程提供了如何用Python API实现PassInstrument的例子。

在当前PassContext中重载Instrument

override_instruments方法用于覆盖当前PassContext的instrument。例如,如果pass在运行时没有显式地创建一个新的PassContext,仍然可以通过以下方式将PassInstrument注册到全局PassContext:

cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()

 注意,当调用override_instruments时,会调用旧PassInstrument实例的exit_pass_ctx方法。然后调用新的PassInstrument的enter_pass_ctx方法。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值