在Relay中添加一个编译pass

本文深入探讨了Relay编译器Pass的机制,包括AST遍历、ExprFunctor、ExprVisitor和ExprMutator的使用。通过一个常量折叠Pass的实例,展示了如何分析和修改 Relay 程序。常量折叠Pass涉及检查表达式是否为常量并用计算结果替换,以优化编译阶段。最后,讨论了如何在TVMAPI中注册和使用Pass。
摘要由CSDN通过智能技术生成

本文翻译自Adding a Compiler Pass to Relay — tvm 0.9.dev0 documentation

编译器pass是扩展Relay特性集和对Relay程序进行优化的主要接口。通过编写编译器pass,您可以修改AST或收集有关AST的信息,这取决于您的目标。事实上,Relay的一些最重要的内部特性(例如自动微分和类型推断)只不过是“标准的”编译器pass。

从较高的层次上讲,编写pass有两个关键组件:

  • 创建一个或多个遍历程序的c++类
  • 将遍历实现及其元数据包装在pass管理器API中,以便它能够与pass基础设施灵活地交互

首先,我们将概述编写编译器pass的关键机制。然后,我们将分析一个Relay常量折叠pass的例子。

AST遍历 

用于遍历Relay程序的基类是ExprFunctor。它提供的公共接口是一个VisitExpr方法,该方法接受一个表达式和零个或多个参数,并返回某种类型的实例。当您扩展这个类时,您可以通过重载每种表达式类型的VisitExpr_来定义AST遍历模式。

VisitExpr和VisitExpr_之间的关系与调度有关。每个VisitExpr_的定义都以特定类型的表达式为目标,但您并不总是知道将要访问的节点类型。为了解决这个问题,ExprFunctor提供了一个VisitExpr函数,该函数从给定表达式路由到处理该表达式的VisitExpr_实例。虽然c++已经提供了动态调度,但ExprFunctor定义了自己的虚函数表,以供VisitExpr使用。通过定义自己的虚函数表,我们可以更好地控制分派。例如,如果我们想定义一个PrintVisitor遍历器,在每次访问之前打印" Here ",我们可以重载VisitExpr:

void PrintVisitor::VisitExpr(const Expr& expr) {
  std::cout << "Here" << std::endl;
  ExprFunctor::VisitExpr(expr);
}

ExprFunctor本身是一个非常通用的类,这就是为什么通常需要扩展ExprVisitor或ExprMutator的原因。这些类扩展了ExprFunctor,并提供了VisitExpr_的默认实现,以便按照自己实现的通用遍历模式遍历每种表达式。拥有这些默认实现,意味着我们仅仅为需要特殊处理的表达式类型提供重载实现即可。我们将在下面几节中单独描述每个子类。

表达式访问

ExprVisitor适用于不修改程序,而是执行程序分析和收集信息的pass。使用这个类,VisitExpr和pass私有重载不返回任何东西。这个类提供的VisitExpr_实现只是访问表达式的所有字段。IfNode的默认实现如下所示。

void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}

注意,我们在这里调用的是VisitExpr而不是VisitExpr_,所以我们可以使用ExprFunctor中的虚函数表进行路由。

现在,如果我们想编写一个类CallChecker来检查程序中是否出现了任何函数调用,我们只需要扩展ExprVisitor并定义以下VisitExpr_方法:

void VisitExpr_(const CallNode* n) final {
  result_ = true;
}

其中result_是一个类成员变量。在本例中,我们不需要在CallNode的字段上进一步递归,因为result_已经为真,我们现在知道原始表达式包含一个调用。为了使这个访问者可用,我们可以提供以下公共方法:

bool Check(const Expr& expr) final {
  result_ = false;
  VisitExpr(expr);
  return result_;
}

这就是我们所需要的。在调用顶层递归之前,定义一个公共接口执行一些记录是很常见的。当然,我们可以通过创建一个独立的过程来创建一个CallChecker实例并对其调用Check,来进一步包装API,但这里,我们已经用很少的工作实现了我们的目标。

表达式调整器

 ExprMutator用于需要以某种方式转换程序的pass。使用这个类,VisitExpr及其私有对应对象返回Expr实例。这个类提供的默认VisitExpr_实现访问表达式的所有字段,并以这些字段作为返回结果。TupleGetItemNode的默认实现如下所示。

Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
  auto t = this->Mutate(g->tuple);
  if (g->tuple == t) {
    return GetRef<Expr>(g);
  } else {
    return TupleGetItem(t, g->index);
  }
}

这里有一些需要注意的事项。首先,Mutate是ExprMutator中VisitExpr的别名。其次,只有在调用Mutate修改了元组字段时,才返回一个新节点。这种更新方法称为功能更新,这样做可以避免不必要的内存分配。

ExprMutator有一个ExprVisitor不具备的特性,那就是用于缓存结果的内置memo_字段。ExprMutator有一个记忆器是有意义的,因为我们知道缓存的是哪种类型的结果(即Expr),而ExprVisitor的visit方法不返回任何东西。通常,当我们想在ExprVisitor的子类中缓存结果时,需要自己定义缓存。

现在,如果我们想写一个IfCollapser类,用它的true分支替换每个if语句,我们应该为IfNode重写VisitExpr_:

Expr ExprMutator::VisitExpr_(const IfNode* op) {
  return this->Mutate(op->true_branch);
}

注意,返回的表达式不一定是IfNode,这很好,因为返回类型是Expr。现在,我们创建公共接口:

Expr CollapseIfs(const Expr& expr) final {
  return this->Mutate(expr);
}

使用这个mutator,我们不需要做任何记录,但是我们仍然希望遵循使用描述性方法作为接口的惯例。

示例:常量折叠

为了更好地理解编写pass的过程,我们将分析常量折叠pass(见src/relay/transforms/fold_constant.cc),因为它是一个包含了两种类型的遍历,相对简单的pass。

常量折叠涉及对程序中只含常量值的表达式求值,然后用计算结果替换这些表达式。这个pass的目标是在编译阶段尽可能对可以计算出结果的表达式预先计算。为了实现这一点,常量折叠pass使用了一个访问器(ConstantChecker)和一个调整器 (ConstantFolder)。

访问器ConstantChecker

此访问器用于检查表达式是否为常量。在Relay中,如果表达式是ConstantNode或者是只有常量字段的TupleNode,则该表达式为常量。

我们使用一个memo_字段将节点映射到它们是否为常量,并缓存这些结果。下面是ConstantChecker中的VisitExpr_实现。

void VisitExpr_(const ConstantNode* n) final {
  memo_[GetRef<Constant>(n)] = true;
}

void VisitExpr_(const TupleNode* n) final {
  bool result = true;
  for (const auto& field : n->fields) {
    if (!Check(field)) {
      result = false;
      break;
    }
  }
  memo_[GetRef<Tuple>(n)] = result;
}

Check方法返回给定表达式是否被认为是常量。

bool Check(const Expr& expr) {
  const auto it = memo_.find(expr);
  if (it != memo_.end())
    return it->second;
  VisitExpr(expr);
  return memo_[expr];
}

我们不会为遇到的每个节点修改memo_;相反,我们只在遇到可能是常量的节点时才修改memo_。当memo_不包含expr时默认为false。

ConstantFolder调整器

这个mutator执行常量折叠pass的大部分流程,并在内部使用ConstantChecker。在Relay中,有三种节点类型涉及到常量折叠:LetNode、TupleItemGetNode和CallNode。在接下来的段落中,我们将解释它们在pass中的角色。 

Expr VisitExpr_(const LetNode* op) final {
  Expr value = this->Mutate(op->value);
  if (value.as<ConstantNode>()) {
    memo_[op->var] = value;
    return this->Mutate(op->body);
  } else {
    Var var = Downcast<Var>(this->Mutate(op->var));
    Expr body = this->Mutate(op->body);
    if (var.same_as(op->var) &&
        value.same_as(op->value) &&
        body.same_as(op->body)) {
      return GetRef<Expr>(op);
    } else {
      return Let(var, value, body);
    }
  }
}

在LetNode实例中,我们首先尝试对表达式中绑定的value进行常量折叠。如果可以,那么我们填写memo_并返回访问body的结果。访问body本质上是将绑定的value值传播到它在body中的使用的地方。如果不能对绑定值进行常量折叠,则返回一个等同的Let表达式。

Expr VisitExpr_(const TupleGetItemNode* op) final {
  Expr res = ExprMutator::VisitExpr_(op);
  op = res.as<TupleGetItemNode>();
  if (const auto* tuple = op->tuple.as<TupleNode>()) {
    return tuple->fields[op->index];
  } else {
    return res;
  }
}

在TupleItemGetNode实例中,我们检查op->tuple字段是否为TupleNode。如果是,则返回op->index指向的元组字段。这里我们检查是否为TupleNode的原因是, op->tuple可能被认为是一个元组,单它本身不是元组。

Expr VisitExpr_(const CallNode* call) final {
  static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
  Expr res = ExprMutator::VisitExpr_(call);
  call = res.as<CallNode>();
  // We don't constant fold function with zero arguments.
  // This is a heuristic that is useful.
  // For example it is harmful to fold ones(shape=(4, 5)).
  if (call->args.size() == 0) return res;
  const OpNode* op = call->op.as<OpNode>();
  if (op == nullptr) return res;
  // skip stateful ops.
  if (op_stateful.get(GetRef<Op>(op), false)) return res;
  bool all_const_args = true;
  for (Expr arg : call->args) {
    if (!checker_.Check(arg)) {
      all_const_args = false;
    }
  }
  if (all_const_args) {
    return ConstEvaluate(res);
  } else {
    return res;
  }
}

在CallNode实例中,我们首先使用ExprMutator的VisitExpr_访问函数调用,对该调用的所有参数字段进行常量折叠。我们使用ExprMutator::VisitExpr_而不是VisitExpr,是因为我们想绕过虚函数表(以避免无限循环),而使用ExprMutator提供的默认实现。然后,只有当所有参数都是常量时(使用ConstantChecker)才计算调用。对调用求值会产生一个值,因此我们使用helper方法ValueToExpr来允许我们将求值表达式放回AST中。

现在,我们为常量折叠构造一个更方便的接口FoldConstant。FoldConstant是一个独立于ConstantFolder类之外的函数,它接受一个表达式,并在内部创建和使用ConstantFolder实例(完整的定义可以在src/relay/transforms/fold_constant.cc中找到)。

使用pass管理器注册pass 

写好AST遍历器后,可以通过以下代码将pass注册为TVM API:

namespace transform {

Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

}  // namespace transform

如果由上面的代码产生的Pass对象被传递给pass基础设施,它将确保AST遍历应用到给定Relay模块中的每一个函数,这是人们预期中常数折叠pass(它应该尽可能折叠所有常量)。

函数CreateFunctionPass允许注册pass的优化级别(在本例中是2),这个级别可以用于根据pass功能、名称和依赖关系将pass分组。一个pass的依赖pass以列表形式给出,这些依赖pass的运行结果是当前pass的必要条件。FoldConstant没有任何依赖,但是许多Relay pass确实依赖于类型信息,所以InferType是一个常见的依赖;其他的可能依赖于程序的正规l形式,程序的正规形式可以通过ToANormalForm pass得到。

注意PassContext对象包含一个pass用于报告错误和配置选项的信息;FoldConstant不需要这些信息,但其他pass可能会引用它们的PassContext对象。

现在pass可以通过pass基础架构调用,不过最好也为pass添加一个Python绑定,如下所示:

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

一旦以上述方式定义了Pass对象,就可以使用Pass基础结构的Sequential结构来调用它们,该结构接受一个pass列表,并将它们按顺序应用到一个Relay模块,从而获得一个转换后的模块。例如,下面的代码将FoldConstant和ToANormalForm pass(一个接一个)应用到模块中的每个函数,并获得一个新模块。

seq = transform.Sequential([
    relay.transform.FoldConstant(),
    relay.transform.ToANormalForm()
])
new_mod = seq(mod)

关于注册的更多细节参见TVM Runtime System,关于pass管理器接口的更多信息参见Pass Infrastructure。Rlay的标准pass列在include/tvm/relay/transform.h中,并在src/relay/transforms/中实现。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值