本文翻译自Adding a Compiler Pass to Relay — tvm 0.9.dev0 documentation
编译器pass是扩展Relay特性集和对Relay程序进行优化的主要接口。通过编写编译器pass,您可以修改AST或收集有关AST的信息,这取决于您的目标。事实上,Relay的一些最重要的内部特性(例如自动微分和类型推断)只不过是“标准的”编译器pass。
从较高的层次上讲,编写pass有两个关键组件:
- 创建一个或多个遍历程序的c++类
- 将遍历实现及其元数据包装在pass管理器API中,以便它能够与pass基础设施灵活地交互
首先,我们将概述编写编译器pass的关键机制。然后,我们将分析一个Relay常量折叠pass的例子。
AST遍历
用于遍历Relay程序的基类是ExprFunctor。它提供的公共接口是一个VisitExpr方法,该方法接受一个表达式和零个或多个参数,并返回某种类型的实例。当您扩展这个类时,您可以通过重载每种表达式类型的VisitExpr_来定义AST遍历模式。
VisitExpr和VisitExpr_之间的关系与调度有关。每个VisitExpr_的定义都以特定类型的表达式为目标,但您并不总是知道将要访问的节点类型。为了解决这个问题,ExprFunctor提供了一个VisitExpr函数,该函数从给定表达式路由到处理该表达式的VisitExpr_实例。虽然c++已经提供了动态调度,但ExprFunctor定义了自己的虚函数表,以供VisitExpr使用。通过定义自己的虚函数表,我们可以更好地控制分派。例如,如果我们想定义一个PrintVisitor遍历器,在每次访问之前打印" Here ",我们可以重载VisitExpr:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << "Here" << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor本身是一个非常通用的类,这就是为什么通常需要扩展ExprVisitor或ExprMutator的原因。这些类扩展了ExprFunctor,并提供了VisitExpr_的默认实现,以便按照自己实现的通用遍历模式遍历每种表达式。拥有这些默认实现,意味着我们仅仅为需要特殊处理的表达式类型提供重载实现即可。我们将在下面几节中单独描述每个子类。
表达式访问
ExprVisitor适用于不修改程序,而是执行程序分析和收集信息的pass。使用这个类,VisitExpr和pass私有重载不返回任何东西。这个类提供的VisitExpr_实现只是访问表达式的所有字段。IfNode的默认实现如下所示。
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
注意,我们在这里调用的是VisitExpr而不是VisitExpr_,所以我们可以使用ExprFunctor中的虚函数表进行路由。
现在,如果我们想编写一个类CallChecker来检查程序中是否出现了任何函数调用,我们只需要扩展ExprVisitor并定义以下VisitExpr_方法:
void VisitExpr_(const CallNode* n) final {
result_ = true;
}
其中result_是一个类成员变量。在本例中,我们不需要在CallNode的字段上进一步递归,因为result_已经为真,我们现在知道原始表达式包含一个调用。为了使这个访问者可用,我们可以提供以下公共方法:
bool Check(const Expr& expr) final {
result_ = false;
VisitExpr(expr);
return result_;
}
这就是我们所需要的。在调用顶层递归之前,定义一个公共接口执行一些记录是很常见的。当然,我们可以通过创建一个独立的过程来创建一个CallChecker实例并对其调用Check,来进一步包装API,但这里,我们已经用很少的工作实现了我们的目标。
表达式调整器
ExprMutator用于需要以某种方式转换程序的pass。使用这个类,VisitExpr及其私有对应对象返回Expr实例。这个类提供的默认VisitExpr_实现访问表达式的所有字段,并以这些字段作为返回结果。TupleGetItemNode的默认实现如下所示。
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItem(t, g->index);
}
}
这里有一些需要注意的事项。首先,Mutate是ExprMutator中VisitExpr的别名。其次,只有在调用Mutate修改了元组字段时,才返回一个新节点。这种更新方法称为功能更新,这样做可以避免不必要的内存分配。
ExprMutator有一个ExprVisitor不具备的特性,那就是用于缓存结果的内置memo_字段。ExprMutator有一个记忆器是有意义的,因为我们知道缓存的是哪种类型的结果(即Expr),而ExprVisitor的visit方法不返回任何东西。通常,当我们想在ExprVisitor的子类中缓存结果时,需要自己定义缓存。
现在,如果我们想写一个IfCollapser类,用它的true分支替换每个if语句,我们应该为IfNode重写VisitExpr_:
Expr ExprMutator::VisitExpr_(const IfNode* op) {
return this->Mutate(op->true_branch);
}
注意,返回的表达式不一定是IfNode,这很好,因为返回类型是Expr。现在,我们创建公共接口:
Expr CollapseIfs(const Expr& expr) final {
return this->Mutate(expr);
}
使用这个mutator,我们不需要做任何记录,但是我们仍然希望遵循使用描述性方法作为接口的惯例。
示例:常量折叠
为了更好地理解编写pass的过程,我们将分析常量折叠pass(见src/relay/transforms/fold_constant.cc),因为它是一个包含了两种类型的遍历,相对简单的pass。
常量折叠涉及对程序中只含常量值的表达式求值,然后用计算结果替换这些表达式。这个pass的目标是在编译阶段尽可能对可以计算出结果的表达式预先计算。为了实现这一点,常量折叠pass使用了一个访问器(ConstantChecker)和一个调整器 (ConstantFolder)。
访问器ConstantChecker
此访问器用于检查表达式是否为常量。在Relay中,如果表达式是ConstantNode或者是只有常量字段的TupleNode,则该表达式为常量。
我们使用一个memo_字段将节点映射到它们是否为常量,并缓存这些结果。下面是ConstantChecker中的VisitExpr_实现。
void VisitExpr_(const ConstantNode* n) final {
memo_[GetRef<Constant>(n)] = true;
}
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef<Tuple>(n)] = result;
}
Check方法返回给定表达式是否被认为是常量。
bool Check(const Expr& expr) {
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr];
}
我们不会为遇到的每个节点修改memo_;相反,我们只在遇到可能是常量的节点时才修改memo_。当memo_不包含expr时默认为false。
ConstantFolder调整器
这个mutator执行常量折叠pass的大部分流程,并在内部使用ConstantChecker。在Relay中,有三种节点类型涉及到常量折叠:LetNode、TupleItemGetNode和CallNode。在接下来的段落中,我们将解释它们在pass中的角色。
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let(var, value, body);
}
}
}
在LetNode实例中,我们首先尝试对表达式中绑定的value进行常量折叠。如果可以,那么我们填写memo_并返回访问body的结果。访问body本质上是将绑定的value值传播到它在body中的使用的地方。如果不能对绑定值进行常量折叠,则返回一个等同的Let表达式。
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
}
}
在TupleItemGetNode实例中,我们检查op->tuple字段是否为TupleNode。如果是,则返回op->index指向的元组字段。这里我们检查是否为TupleNode的原因是, op->tuple可能被认为是一个元组,单它本身不是元组。
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
在CallNode实例中,我们首先使用ExprMutator的VisitExpr_访问函数调用,对该调用的所有参数字段进行常量折叠。我们使用ExprMutator::VisitExpr_而不是VisitExpr,是因为我们想绕过虚函数表(以避免无限循环),而使用ExprMutator提供的默认实现。然后,只有当所有参数都是常量时(使用ConstantChecker)才计算调用。对调用求值会产生一个值,因此我们使用helper方法ValueToExpr来允许我们将求值表达式放回AST中。
现在,我们为常量折叠构造一个更方便的接口FoldConstant。FoldConstant是一个独立于ConstantFolder类之外的函数,它接受一个表达式,并在内部创建和使用ConstantFolder实例(完整的定义可以在src/relay/transforms/fold_constant.cc中找到)。
使用pass管理器注册pass
写好AST遍历器后,可以通过以下代码将pass注册为TVM API:
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
} // namespace transform
如果由上面的代码产生的Pass对象被传递给pass基础设施,它将确保AST遍历应用到给定Relay模块中的每一个函数,这是人们预期中常数折叠pass(它应该尽可能折叠所有常量)。
函数CreateFunctionPass允许注册pass的优化级别(在本例中是2),这个级别可以用于根据pass功能、名称和依赖关系将pass分组。一个pass的依赖pass以列表形式给出,这些依赖pass的运行结果是当前pass的必要条件。FoldConstant没有任何依赖,但是许多Relay pass确实依赖于类型信息,所以InferType是一个常见的依赖;其他的可能依赖于程序的正规l形式,程序的正规形式可以通过ToANormalForm pass得到。
注意PassContext对象包含一个pass用于报告错误和配置选项的信息;FoldConstant不需要这些信息,但其他pass可能会引用它们的PassContext对象。
现在pass可以通过pass基础架构调用,不过最好也为pass添加一个Python绑定,如下所示:
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
一旦以上述方式定义了Pass对象,就可以使用Pass基础结构的Sequential结构来调用它们,该结构接受一个pass列表,并将它们按顺序应用到一个Relay模块,从而获得一个转换后的模块。例如,下面的代码将FoldConstant和ToANormalForm pass(一个接一个)应用到模块中的每个函数,并获得一个新模块。
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
关于注册的更多细节参见TVM Runtime System,关于pass管理器接口的更多信息参见Pass Infrastructure。Rlay的标准pass列在include/tvm/relay/transform.h中,并在src/relay/transforms/中实现。