将编译器转换添加到Relay

14 篇文章 6 订阅

将编译器转换添加到Relay

 

AST遍历

表达式访问器

表达式变换器

示例:常量折叠

【ConstantChecker】访问器

【ConstantFolder】转换器

向转换管理器注册一个Pass


 

编译器转换是扩展Relay功能集和对Relay程序执行优化的主要接口。通过编写编译器转换,可以根据您的目标来修改AST或收集有关AST的信息。确实,Relay的一些最重要的内置功能(例如,autodiff和类型推断)仅是“标准”编译器转换。

总体而言,编写转换有两个关键要素:

  • 创建一个或多个遍历程序的C++类

  • 在转换管理器API中包装遍历实现及其元数据,以便它可以与Relay转换基础架构巧妙地对接

首先,我们将概述编写编译器转换的关键机制。然后,我们将通过Relay进行常数折叠遍历的具体示例。

AST遍历

用于遍历Rel​​ay程序的基类为【ExprFunctor】。它提供的公共接口是一种【VisitExpr】方法,该方法接受一个表达式和零个或多个参数,并返回某种类型的实例。扩展此类时,可以通过为每种类型的表达式重写【VisitExpr_】的实现来定义AST遍历模式。

VisitExpr】和【VisitExpr_】之间的关系与调度有关。每个【VisitExpr_】定义都针对特定的表达式类型,但是您并不总是知道要访问的节点类型。为了解决这个问题,【ExprFunctor】提供了一个VisitExpr函数,该函数将从给定的表达式路由到【VisitExpr_】以便能够处理它。尽管C ++已经提供了动态调度,但【ExprFunctor】仍定义了自己的vtable来给【VisitExpr】使用。通过定义自己的vtable,我们可以更好地控制调度。例如,如果我们想定义一个【PrintVisitor】遍历器,在每次访问之前都打印“ Here” ,则可以覆盖VisitExpr

void PrintVisitor::VisitExpr(const Expr& expr) {
  std::cout << "Here" << std::endl;
  ExprFunctor::VisitExpr(expr);
}

【ExprFunctor】本身是一个非常普通的类,这就是为什么您会经常扩展【ExprVisitor】或【ExprMutator】的原因。这些类扩展【ExprFunctor】并提供【VisitExpr_ 】的默认实现捕获每种表达式类型的常见遍历模式。拥有这些默认实现意味着我们仅需要为需要不同行为的表达式类型提供重写实现。在以下各节中,我们将分别描述每个子类。

表达式访问器

【ExprVisitor】用于转换,它不修改程序而是执行程序分析并收集信息。在此类中,【VisitExpr】 与之相对的部分没有返回值。该类提供的【VisitExpr_ 】实现只访问作为表达式的所有表达式字段。【IfNode】的默认实现如下所示。

void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}

请注意,我们在这里调用【VisitExpr】而不是【VisitExpr_】,因此我们可以在【ExprFunctor】中使用vtable 进行路由。

现在,如果我们想编写一个【CallChecker】类来检查程序中是否有任何函数调用,则只需扩展【ExprVisitor】 并定义以下【VisitExpr_】方法:

void VisitExpr_(const CallNode* n) final {
  result_ = true;
}

这里【result_】是一个字段。在这种情况下,我们不需要在的字段上进一步递归【CallNode】,因为【result_】它已经为真,并且我们现在知道原始表达式包含一个调用。为了使此访问器可用,我们将提供以下公共方法:

bool Check(const Expr& expr) final {
  result_ = false;
  VisitExpr(expr);
  return result_;
}

这就是我们所需要的。在调用顶级递归之前,定义一个执行一些簿记工作的公共接口是很常见的。当然,我们可以通过创建一个标准的程序来进一步包装API ,这个程序创建【CallChecker】实例并调用它的【Check】,但值得庆幸的是,我们只需付出很少的努力就可以实现目标。

表达式变换器

【ExprMutator】用于以某种方式变换程序的转换。有了这个类,【VisitExpr】它的私有对应部分就会返回【Expr】。此类提供的【VisitExpr_】默认实现将访问作为表达式的所有表达式的字段,并将这些字段设置为访问它们的结果。【TupleGetItemNode】的默认实现 如下所示。

Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
  auto t = this->Mutate(g->tuple);
  if (g->tuple == t) {
    return GetRef<Expr>(g);
  } else {
    return TupleGetItemNode::make(t, g->index);
  }
}

这里有一些注意事项。首先,【Mutate】是【ExprMutator】里面【VisitExpr】的一个别名 。第二,如果调用【Mutate】修改了【tuple】字段,我们仅返回一个新节点。这种更新方法称为功能更新,这样做可以避免不必要的分配。

ExprMutator】一个功能是对于缓存结果,【ExprVisitor】没有是一个内置【memo_】 字段。它使这个意义上说【ExprMutator】有内存,因为我们知道哪些结果类型我们缓存(即【Expr】),而访问【ExprVisitor】方法不返回任何东西。通常,当我们要缓存【ExprVisitor】子类中的结果时 ,我们需要自己定义缓存。

现在,如果我们想写一个【IfCollapser】类,用为true的分支替换每个if语句,我们会覆盖【IfNode】的【VisitExpr_】为 :

Expr ExprMutator::VisitExpr_(const IfNode* op) {
  return this->Mutate(op->true_branch);
}

请注意,返回的表达式不一定是一个【IfNode】,这很好,因为返回类型是【Expr】。现在,我们创建公共接口:

Expr CollapseIfs(const Expr& expr) final {
  return this->Mutate(expr);
}

使用此转换器,我们不需要进行任何簿记,但是我们仍然希望遵循以描述性方法作为接口的约定。

示例:常量折叠

为了更好地理解编写转换的过程,我们将以常量折叠转换(可在src / relay / pass / fold_constant.cc中找到)作为指导,因为这是一种相对简单的转换,其中包含了两种遍历。

常量折叠涉及在程序中评估仅包含常量值的表达式,然后将这些表达式替换为求值结果。此过程的目标是尽我们所能进行所有计算。为了实现这一点,常量的折叠转换利用了一个访问器(【ConstantChecker】)和一个转换器(【ConstantFolder】)。

ConstantChecker】访问器

此访问器用于检查表达式是否为常数。在Relay中,我们将表达式定义为常数,如果它是一个只有常数字段的【ConstantNode】 或【TupleNode】。

我们使用一个【memo_】字段从节点映射到它们是否是常数并缓存这些结果。以下是【ConstantChecker】中的【VisitExpr_】定义 。

void VisitExpr_(const ConstantNode* n) final {
  memo_[GetRef<Constant>(n)] = true;
}

void VisitExpr_(const TupleNode* n) final {
  bool result = true;
  for (const auto& field : n->fields) {
    if (!Check(field)) {
      result = false;
      break;
    }
  }
  memo_[GetRef<Tuple>(n)] = result;
}

用于协调这些定义的簿记是【Check】,返回给定表达式是否被视为常量的方法。

bool Check(const Expr& expr) {
  const auto it = memo_.find(expr);
  if (it != memo_.end())
    return it->second;
  VisitExpr(expr);
  return memo_[expr];
}

我们不会为遇到的每个节点修改【memo_】;相反,我们仅在遇到的节点可能为常数时修改【memo_】 。然后,当【memo_】不包含【expr】时,我们根据默认值为false 。

ConstantFolder】转换器

该转换器执行大量的常数折叠转换,并在内部使用【ConstantChecker】。在Relay中,有三个节点的类型涉及常量折叠:LetNodeTupleItemGetNode,和 CallNode。在下面的段落中,我们将说明每个过程中的角色。

Expr VisitExpr_(const LetNode* op) final {
  Expr value = this->Mutate(op->value);
  if (value.as<ConstantNode>()) {
    memo_[op->var] = value;
    return this->Mutate(op->body);
  } else {
    Var var = Downcast<Var>(this->Mutate(op->var));
    Expr body = this->Mutate(op->body);
    if (var.same_as(op->var) &&
        value.same_as(op->value) &&
        body.same_as(op->body)) {
      return GetRef<Expr>(op);
    } else {
      return LetNode::make(var, value, body);
    }
  }
}

在【LetNode】情况下,我们首先尝试对表达式中绑定的值进行常量折叠。如果可以的话,我们将填充【memo_】并返回访问的结果-本质上是将绑定值传播到函数体中的使用点。如果我们无法常量折叠绑定值,我们将模仿默认实现。

Expr VisitExpr_(const TupleGetItemNode* op) final {
  Expr res = ExprMutator::VisitExpr_(op);
  op = res.as<TupleGetItemNode>();
  if (const auto* tuple = op->tuple.as<TupleNode>()) {
    return tuple->fields[op->index];
  } else {
    return res;
  }
}

在【TupleItemGetNode】情况下,我们检查【op->tuple】字段是否为【TupleNode】 。如果是这样,我们通过【op->index】用元组所指向的字段替换元组 。我们需要检查的原因是因为【op->tuple】可能会评估为一个元组,而本身不是元组。

Expr VisitExpr_(const CallNode* call) final {
  static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
  Expr res = ExprMutator::VisitExpr_(call);
  call = res.as<CallNode>();
  // We don't constant fold function with zero arguments.
  // This is a heuristic that is useful.
  // For example it is harmful to fold ones(shape=(4, 5)).
  if (call->args.size() == 0) return res;
  const OpNode* op = call->op.as<OpNode>();
  if (op == nullptr) return res;
  // skip stateful ops.
  if (op_stateful.get(GetRef<Op>(op), false)) return res;
  bool all_const_args = true;
  for (Expr arg : call->args) {
    if (!checker_.Check(arg)) {
      all_const_args = false;
    }
  }
  if (all_const_args) {
    return ConstEvaluate(res);
  } else {
    return res;
  }
}

在【CallNode】情况下,我们首先使用【ExprMutator 】中的【VisitExpr_】来访问该调用,该访问将折叠该调用的所有字段。我们使用【ExprMutator::VisitExpr_】 而不是【VisitExpr】,是因为我们要绕过vtable(以避免无限循环)并使用【ExprMutator】提供的默认实现。然后,仅当所有参数都是常量(使用【ConstantChecker】)时,我们才评估调用。评估调用会产生一个value,因此我们使用了一个辅助方法【ValueToExpr】来允许我们将评估后的表达式放回AST中。

现在,我们为常量折叠构造了一个更方便的接口【FoldConstant】。【FoldConstant】是【ConstantFolder 】类之外的一个独立函数,它接受一个表达式并在内部创建并使用一个 【ConstantFolder】实例(完整的定义可以在 src / relay / pass / fold_constant.cc中找到)。

向转换管理器注册一个Pass

注意:有关此主题的更多详细信息,请参见:ref-`relay-pass-infra`上的文档。

编写AST遍历器后,可以使用以下代码将访问权限注册为TVM API端点:

namespace transform {

Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

}  // namespace transform

如果由上述代码产生的【Pass】对象提供给转换基础结构,它将确保AST遍历应用于给定Relay模块中的每个函数,这是人们希望常数折叠转换的行为(它应该折叠所有常量)。

该【CreateFunctionPass 】函数允许为转换注册优化级别(在这种情况下为2),可用于基于流程的通用工具,转换的名称以及转换的任何依赖组合在一起。转换的依赖性作为所有转换的列表给出,其结果对于运行当前转换是必需的。【FoldConstant】没有任何依赖关系,但是许多Relay转换确实依赖于具有类型信息,因此【InferType】是常见的依赖关系;其他的可能依赖于A-normal形式的程序,通过ToANormalForm转换

请注意,【PassContext】对象包含用于错误报告和配置选项的信息;【FoldConstant】不需要此信息,但是其他过程可以引用其【PassContext】对象。

现在可以通过转换基础结构调用转换,尽管最好为转换添加一个Python绑定,如以下代码片段所示:

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

一旦【Pass】目标以以上述方式定义了对象,就可以使用转换基础结构的【Sequential】构造来调用它们,该构造将获取转换列表并将其依次应用于Relay模块,从而获得转换后的模块。例如,以下代码将【FoldConstant】和【ToANormalForm】转换(一个接一个)应用于【mod】中的每个函数,并获得一个新模块。

seq = transform.Sequential([
    relay.transform.FoldConstant(),
    relay.transform.ToANormalForm()
])
new_mod = seq(mod)

关于注册的更多详细信息可以在TVM运行系统中找到,有关转换管理器接口的更多信息可以在Relay转换基础结构中找到。Relay的标准转换列在include / tvm / relay / transform.h中,并在`src / relay / pass /`_中实现

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值