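TVM TIR Pass - CSE (Common Subexpression Elimination): Optimization Principles and Code Walkthrough (PR#9482)

Theory

Prerequisites

  • TIR Let Binding
    • Let(var, value, body) evaluates value, binds the result to var, and then returns the result of evaluating body. In other words, let binds an expression (Expr) to an immutable var that is local to body.
  • Scope: a scope, or code block. The TVM AST organizes scopes as a tree: an outer scope that contains an inner scope is the parent node, and the inner scope is a child node. Variables of a parent are visible in its children, but not the other way around. Deciding whether a variable is in a given scope is an important part of the CSE algorithm.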
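
The let and scope rules above can be tried out on a toy model. The sketch below is not TVM's API: it assumes expressions are nested Python tuples, and the names `evaluate`, `program` and `leaked` are made up for illustration.

```python
# Toy model (an assumption of this sketch, not TVM's data structures):
#   ("let", var, value, body), ("add", a, b), a variable name (str), or an int.

def evaluate(expr, env=None):
    """Evaluate a toy expression under an environment mapping names to values."""
    env = {} if env is None else env
    if isinstance(expr, int):
        return expr
    if isinstance(expr, str):
        return env[expr]  # raises KeyError if the variable is not in scope
    tag = expr[0]
    if tag == "add":
        return evaluate(expr[1], env) + evaluate(expr[2], env)
    if tag == "let":
        _, var, value, body = expr
        # The binding is visible only while evaluating `body` (the child scope).
        child_env = {**env, var: evaluate(value, env)}
        return evaluate(body, child_env)
    raise ValueError(f"unknown node: {tag}")

# let x = 1 in (let y = x + 2 in x + y)
program = ("let", "x", 1, ("let", "y", ("add", "x", 2), ("add", "x", "y")))
result = evaluate(program)

# `y` is bound in the inner scope only, so using it outside that scope fails.
try:
    evaluate(("add", ("let", "y", 1, "y"), "y"))
    leaked = True
except KeyError:
    leaked = False
```

The inner let can read x from its parent scope, while the outer expression cannot see y, which is the parent/child visibility rule described above.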

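What it is

TIR code is produced from Relay and by other passes. Because it is generated automatically, it tends to contain many repeated computations. Common Subexpression Elimination (CSE) is one of the TIR passes; it locates repeated computations and replaces them.

  • It creates a new variable and replaces every occurrence of the expression with it.
  • It supports replacing whole expressions.
  • It supports replacing subexpressions. For example, (w+x)+(y+z); (w+x)+u; => new_var = (w+x); new_var+(y+z); new_var+u; here (w+x) is the subexpression.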

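How it works

Precondition

The SSA (Static Single Assignment) property of TIR: the value of a variable never changes (variables are immutable). Without this precondition, the replacement would be unsound. For example, suppose y=a+b is used to replace every a+b, but a is reassigned somewhere (say, to 0). Any later a+b then means 0+b, so reusing y without recomputing a+b gives a wrong result.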
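
The failure mode is easy to reproduce in a language with mutable variables. The snippet below is plain Python, not TIR, and the variable names are illustrative only.

```python
# `a` is reassigned between two uses of a + b, which SSA would forbid.
a, b = 1, 2
y = a + b                  # the cached "common subexpression"
a = 0                      # mutation: a + b now means 0 + b

second_reused = y          # what a naive replacement would use
second_recomputed = a + b  # what the program actually means
```

The two values differ (3 versus 2), which is why the pass relies on variables being immutable.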

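Locating repeated expressions

Selecting candidate expressions

  • The expression is not a constant or a variable (if it is already a variable, there is no point in creating a new variable for it).
  • The expression is not a function call or a memory load.
    • Functions are not necessarily pure. For a function with side effects, two calls with the same name and arguments may return different results.
    • Likewise, two memory loads may return different results, for example when the memory is written in between, so such an expression cannot be a candidate.
  • The expression does not contain a function call or a memory load as a subexpression either.
    • Replacing sum(f,f) is also unsafe, because f may have side effects.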
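
The three rules can be written as a small predicate. This is a sketch over a toy tuple representation (an assumption, not TVM's PrimExpr), with "call" and "load" standing in for function calls and memory loads; all function names are hypothetical.

```python
FORBIDDEN = {"call", "load"}  # toy stand-ins for function calls and memory loads

def is_forbidden(expr):
    """True if the node itself is a call or a load."""
    return isinstance(expr, tuple) and expr[0] in FORBIDDEN

def contains_forbidden(expr):
    """True if the node, or any node below it, is a call or a load."""
    if not isinstance(expr, tuple):
        return False
    if is_forbidden(expr):
        return True
    return any(contains_forbidden(child) for child in expr[1:])

def is_eligible(expr):
    """Candidate rule: not a constant, not a variable, no call/load anywhere inside."""
    if not isinstance(expr, tuple):  # an int constant or a str variable
        return False
    return not contains_forbidden(expr)
```

With this model, ("add", "x", "y") is a candidate, while "x", 3, ("call", "f", "x") and ("add", "x", ("load", "buf", "i")) are not.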

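Deciding whether a candidate should be processed further
The variables used by the expression must be visible in the current scope, and the expression must occur more than once.

For an expression that does not meet these conditions, its subexpressions are considered recursively.
For example, (w+x)*(y+z) contains (w+x) and (y+z).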
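
A rough sketch of this decision, again on toy tuples. It assumes that "more than once" means a count of at least 2, and it simplifies the real pass, which rewrites the code and recomputes its table after each new variable. The helper names are hypothetical.

```python
def variables_of(expr):
    """Set of variable names used by a toy expression."""
    if isinstance(expr, str):
        return {expr}
    if isinstance(expr, tuple):
        return set().union(*(variables_of(c) for c in expr[1:]))
    return set()

def direct_subexprs(expr):
    """Immediate children that are themselves computations."""
    return [c for c in expr[1:] if isinstance(c, tuple)]

def select_for_binding(counts, in_scope):
    """counts: list of (expr, occurrences). Returns the expressions worth a new variable."""
    pending = dict(counts)  # expr -> count, in insertion order
    chosen = []
    while pending:
        expr = next(iter(pending))
        n = pending.pop(expr)
        if n >= 2 and variables_of(expr) <= in_scope:
            chosen.append(expr)
        else:
            # Not usable as a whole: fall back to its direct subexpressions.
            for sub in direct_subexprs(expr):
                pending[sub] = pending.get(sub, 0) + n
    return chosen

wx = ("add", "w", "x")
yz = ("add", "y", "z")
counts = [(("mul", wx, yz), 1), (wx, 1)]
picked = select_for_binding(counts, {"w", "x", "y", "z"})
picked_without_w = select_for_binding(counts, {"x", "y", "z"})
```

Here the product occurs only once, so its two operands are considered instead; (w+x) then reaches a count of 2 and is selected, unless w is out of scope.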

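Data structures

  • Context: vector<pair<Var,MaybeValue>>
    • Records which variables are in the current scope.
  • Table of computations: a table that counts expressions, implemented as an unordered_map.
    • The key is an expression (PrimExpr). For example, in the Stmt buffer[i1] = ((x + y) + z), both ((x + y) + z) and (x + y) are PrimExprs; the latter is reached by visiting the subexpressions of the former.
    • The value is the number of times the expression occurs.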
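
A counting table of this kind can be modelled with a `Counter`. The sketch assumes toy tuple nodes, treats only "add" and "mul" as eligible computations, and records the outermost eligible expression without descending into it; `record` is a hypothetical name.

```python
from collections import Counter

OPS = {"add", "mul"}  # node kinds treated as eligible computations in this toy model

def record(node, table):
    """Count the outermost eligible computations in `node` without descending into them."""
    if not isinstance(node, tuple):
        return
    if node[0] in OPS:
        table[node] += 1
        return
    for child in node[1:]:
        record(child, table)

xy = ("add", "x", "y")
xyz = ("add", xy, "z")
stmts = [
    ("store", "buffer", "i1", xyz),
    ("store", "buffer", "i2", xyz),
    ("store", "buffer", "i3", xy),
]
table = Counter()
for stmt in stmts:
    record(stmt, table)
```

For these three stores the table holds ((x + y) + z) twice and (x + y) once; the nested (x + y) is only counted later, if the larger expression turns out not to be usable as a whole.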

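Creating new variables
Because the expressions bound to new variables can contain one another, the variable for a longer expression must be placed in an inner let scope, and the variables for shorter expressions in outer ones.

As the figure shows, the outer smallComp bindings can be y=a+b and z=d+e, while the inner bigComp can be p=y+z.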
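
The nesting order can be sketched as follows, on toy tuples and with hypothetical helpers `complexity`, `substitute` and `wrap_in_lets`. Sorting longest-first and wrapping one let at a time leaves the longest computation innermost.

```python
def complexity(expr):
    """Number of nodes in a toy expression."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + sum(complexity(c) for c in expr[1:])

def substitute(expr, target, var):
    """Replace every occurrence of `target` inside `expr` by `var`."""
    if expr == target:
        return var
    if isinstance(expr, tuple):
        return (expr[0],) + tuple(substitute(c, target, var) for c in expr[1:])
    return expr

def wrap_in_lets(body, computations):
    """Bind each computation to a fresh variable, longest first."""
    result = body
    ordered = sorted(computations, key=complexity, reverse=True)
    for i, target in enumerate(ordered):
        var = f"cse_var_{i + 1}"
        # Also rewrites the values of lets added in earlier iterations.
        result = substitute(result, target, var)
        result = ("let", var, target, result)
    return result

xy = ("add", "x", "y")
xyz = ("add", xy, "z")
body = ("seq", ("store", "buffer", "i1", xyz), ("store", "buffer", "i3", xy))
wrapped = wrap_in_lets(body, [xy, xyz])
```

The outer let binds cse_var_2 = x + y and the inner one binds cse_var_1 = cse_var_2 + z, so the short computation is available when the long one is defined.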

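Limitations and possible improvements

  • Richer semantic equivalences could be supported in the future, for example (x+y)+z <=> z+(x+y).
  • Functions with side effects could be told apart from those without, to enable deeper optimization.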

References

TVM Conference 2021 Qualcomm
TVM 拆包(一):Runtime basics

Implementation

Start with tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse:

@main = primfn(i1: int32, i2: int32, z3: int32) -> () {
  let z1: int32 = 1
  let z2: int32 = 2
   {
    buffer: Pointer(int32)[i1] = (z1 + z2)
    let x: int32 = 1
    let y: int32 = 1
    let a: int32 = ((x + y) + (z1 + z2))
    let b: int32 = ((x + y) + z3)
    buffer[i2] = (a + b)
  }
}

[15:01:55] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, z3: int32) -> () {
  let z1: int32 = 1
  let z2: int32 = 2
  let cse_var_1: int32 = (z1 + z2)
   {
    buffer: Pointer(int32)[i1] = cse_var_1
    let x: int32 = 1
    let y: int32 = 1
    let cse_var_2: int32 = (x + y)
    let a: int32 = (cse_var_2 + cse_var_1)
    let b: int32 = (cse_var_2 + z3)
    buffer[i2] = (a + b)
  }
}

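Observe that CSE created two variables, cse_var_1=z1+z2 and cse_var_2=x+y, and substituted them for the corresponding expressions. The rule applied: z1+z2 and x+y each occur twice within the same scope, and neither involves a load.

test_tir_transform_common_subexpr_elim.py::test_cse_cascade: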

yuan@yuan:~/Coding/compiler/repos/tvm$ python -m pytest /home/yuan/Coding/compiler/repos/tvm/tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse_cascade -s
enabled targets: llvm; cuda; nvptx
pytest marker: 
====================================================================== test session starts ======================================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/yuan/Coding/compiler/repos/tvm
collected 1 item                                                                                                                                                

tests/python/unittest/test_tir_transform_common_subexpr_elim.py @main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  buffer: Pointer(int32)[i1] = ((x + y) + z)
  buffer[i2] = ((x + y) + z)
  buffer[i3] = (x + y)
}

[15:17:37] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  let cse_var_2: int32 = (x + y)
  let cse_var_1: int32 = (cse_var_2 + z)
   {
    buffer: Pointer(int32)[i1] = cse_var_1
    buffer[i2] = cse_var_1
    buffer[i3] = cse_var_2
  }
}

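The result matches the rule described earlier: the longer expression is bound inside, the shorter one outside.

With the examples in mind, we now follow the second one through the source code. Logging is inserted at the recursion entry point, at the function's input and output, and where the counting hash table is built, in order to observe the call logic: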

yuan@yuan:~/Coding/compiler/repos/tvm$ python -m pytest /home/yuan/Coding/compiler/repos/tvm/tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse_cascade -s
enabled targets: llvm; cuda; nvptx
pytest marker: 
============================================================= test session starts ==============================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/yuan/Coding/compiler/repos/tvm
collected 1 item                                                                                                                               

tests/python/unittest/test_tir_transform_common_subexpr_elim.py @main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  buffer: Pointer(int32)[i1] = ((x + y) + z)
  buffer[i2] = ((x + y) + z)
  buffer[i3] = (x + y)
}

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/runtime/logging.cc:239: TVM_LOG_DEBUG enables VLOG statements in 'tir/transforms/common_subexpr_elim.cc' up to level 1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 2)
((x + y), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:493: variables_created true
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
let cse_var_1 = ((x + y) + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = (x + y)

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 1)
((x + y), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:493: variables_created true
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
((x + y), 1)
((cse_var_2 + z), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
((cse_var_2 + z), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
buffer[i1] = cse_var_1

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i1] = cse_var_1

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
buffer[i2] = cse_var_1

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i2] = cse_var_1

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2

[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  let cse_var_2: int32 = (x + y)
  let cse_var_1: int32 = (cse_var_2 + z)
   {
    buffer: Pointer(int32)[i1] = cse_var_1
    buffer[i2] = cse_var_1
    buffer[i3] = cse_var_2
  }
}

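Execution flow diagram
Following the log output, the execution flow can be reconstructed. The box contains the original computation graph. cse_var_1 and then cse_var_2 are added in bottom-up order, while the child nodes are traversed top-down using DFS.

Context update / entry point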

PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
  // At this point, we have already done the generic treatment of introducing (via let-in) what
  // was doable at the toplevel of the given let-in.

  // Save the context at the entry of the function
  Context context_at_entry = context_;

  // Recurse on the `value` field for potentially rewriting it
  PrimExpr value_new = VisitExpr(op->value);

  // Augment the context with the association (`var`, `value`) for preparing the next recursion
  // on the `body`
  context_.push_back({op->var, MaybeValue(op->value)});

  // Recurse on the `body` (with this extended context)
  // The recursive call will have potentially done new simplifications, because in this recursive
  // call `var` will be a part of the context.
  // (see in VisitExpr() that no introduction were performed when a computation was using an
  // undefined variable, as that would lead to ill-formed code)
  PrimExpr body_new = VisitExpr(op->body);

  // Restaure the context to its content at the entrance to not carry out of scope declarations
  // as the variable introduced by the let-in is not in scope outside of its body
  context_ = context_at_entry;

  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
  // have been done.

  // If the `value` and the `body` of the let-in have been rewritten to the same thing
  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
    // then return a reference to the same node
    return GetRef<PrimExpr>(op);
  } else {
    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
    // have just been obtained
    return Let(op->var, value_new, body_new, op->span);
  }
}

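CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) is the entry point that calls VisitExpr, and it does four things:

  • First, it recursively visits value.
  • It adds the variable bound by the let expression to the context, in preparation for recursing into body.
  • It recursively visits body; the let-bound variable is now stored in the context and is therefore visible inside body.
  • It restores the context.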
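
The save / extend / restore pattern can be imitated on toy tuples. `ScopeTracker` below is a made-up class that only records whether each variable use was in scope; it does not rewrite anything.

```python
class ScopeTracker:
    """Toy visitor that mimics the context handling of the let-in case."""

    def __init__(self):
        self.context = []  # (var, value) pairs currently in scope
        self.seen = []     # (variable use, was it in scope?) in visiting order

    def visit(self, expr):
        if isinstance(expr, str):
            in_scope = any(var == expr for var, _ in self.context)
            self.seen.append((expr, in_scope))
        elif isinstance(expr, tuple) and expr[0] == "let":
            _, var, value, body = expr
            context_at_entry = list(self.context)  # save the context
            self.visit(value)                      # `var` is not visible in its own value
            self.context.append((var, value))      # extend the context
            self.visit(body)                       # `var` is visible in the body
            self.context = context_at_entry        # restore the context
        elif isinstance(expr, tuple):
            for child in expr[1:]:
                self.visit(child)

tracker = ScopeTracker()
# (let v = v + 1 in v) + v
tracker.visit(("add", ("let", "v", ("add", "v", 1), "v"), "v"))
```

Only the use of v inside the let body is reported as in scope; the use in the bound value and the use after the let are not.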

Candidate expression rules

bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
  return (
      // In order to be eligible, the given expression should not be a constant
      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
      (expr.as<StringImmNode>() == nullptr)
      // and it should not be a variable
      && (expr.as<VarNode>() == nullptr)
      // and it should not be a forbidden computation (function calls and loads)
      && (!ForbiddenComputation(expr))
      // and it should not even contain a forbidden computation (function calls and loads)
      // the reason is that we don't want to register expressions like (x + f(y)) or
      // (x + Mem[i]) as introducing them into variables could change the semantics
      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
      // and it should not be a ramp node or a broadcast node due to some internals TVM
      // constraints (which check for these node explicitely without performing any
      // evaluation first, so if they have been put into variables it fails)
      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
}

About ComputationTable

  // Obtain the (syntactic) eligible computations done by the input statement, and keep it as
  // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the
  // number of time this exact syntactic computation is being computed.
  ComputationTable table_syntactic_comp_done_by_stmt = ComputationsDoneBy::GetComputationsDoneBy(
      stmt, IsEligibleComputation, CanContainEligibleComputations);	
  ...
  std::unordered_map<Stmt, ComputationTable, ObjectPtrHash, ObjectPtrEqual>
      cache_stmt_table_computations_;
  ...
  void ComputationsDoneBy::VisitStmt(const Stmt& stmt) {
    // See if we have already computed the (table of) computations done by `stmt`
    auto it_table_stmt = cache_.cache_stmt_table_computations_.find(stmt);
    if (it_table_stmt != cache_.cache_stmt_table_computations_.end()) {
      // We need to do the union with `table_of_computations_` instead of just writing into it,
      // because some other childs might have added things into it too. The reason for that is
      // that `table_of_computations_` is shared between the child nodes of a given statement.
      UnionOfComputationTables(&table_of_computations_, it_table_stmt->second);
      return;
    }

    // If we reach this point, it means that we have never computed before the computations done
    // by `stmt` and will do so now.

    // The computations done by a Stmt node are just the ones done by its children
    ComputationTable temp =
        ComputationsDoneByChildrenOf(stmt, is_eligible_computation_, can_contain_computations_);
    // We need to do the union with `table_of_computations_` instead of just writing into it,
    // because some other childs might have added things into it too. The reason for that is
    // that `table_of_computations_` is shared between the child nodes of a given expression.
    UnionOfComputationTables(&table_of_computations_, temp);
  }

 

The program keeps a cache of type unordered_map<Stmt, ComputationTable> that maps each Stmt to its ComputationTable. To use it, look up the ComputationTable with the Stmt as the key.

ComputationTable ComputationsDoneBy::ComputationsDoneByChildrenOf(
    const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
    std::function<bool(const PrimExpr&)> can_contain_computations) {
  // We will be using an instance of the class ComputationsDoneBy for the child nodes
  // (ie, they will share the "result" that `table_of_computations_` is)
  ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations);
  // Calls the *dispatcher* (not the overriden method)
  computations_done_by.StmtExprVisitor::VisitStmt(stmt);
  // So now we can copy table_of_computations_ into the cache for the future queries
  // Note : in the table, the computations done by `stmt` is set to the computations done by its
  // children, because that's exactly what we mean by "the computations of a statement".
  cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_;

  return computations_done_by.table_of_computations_;
}

ComputationsDoneByChildrenOf and ComputationsDoneBy::VisitStmt are in fact mutually recursive, because of

  // Calls the *dispatcher* (not the overriden method)
  computations_done_by.StmtExprVisitor::VisitStmt(stmt);


StmtExprVisitor::VisitStmt recursively visits a Stmt and the expressions it contains.
The execution order can be traced as:
ComputationsDoneBy::VisitStmt (given a Stmt) -> ComputationsDoneByChildrenOf -> computations_done_by.StmtExprVisitor::VisitStmt(stmt) (passes on the next Stmt) -> ComputationsDoneBy::VisitStmt -> …
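
The cached mutual recursion can be mimicked in a few lines. This is a simplified model on toy tuples: the dictionary `cache` plays the role of cache_stmt_table_computations_, and the list `calls` exists only to show which nodes were actually traversed. Both function names are hypothetical.

```python
from collections import Counter

OPS = {"add", "mul"}  # node kinds counted as computations in this toy model
cache = {}            # statement -> Counter of the computations it performs
calls = []            # nodes whose children were actually traversed

def computations_done_by(node):
    """Return the table for `node`, reusing the cache when possible."""
    if node in cache:
        return cache[node]
    return computations_done_by_children_of(node)

def computations_done_by_children_of(node):
    """Union of the tables of the children, stored in the cache afterwards."""
    calls.append(node)
    table = Counter()
    for child in node[1:]:
        if isinstance(child, tuple):
            if child[0] in OPS:
                table[child] += 1
            else:
                table += computations_done_by(child)  # recurse via the cached entry point
    cache[node] = table
    return table

xy = ("add", "x", "y")
xyz = ("add", xy, "z")
seq = ("seq",
       ("store", "buffer", "i1", xyz),
       ("store", "buffer", "i2", xyz),
       ("store", "buffer", "i3", xy))
first = computations_done_by(seq)
traversed_after_first = len(calls)
second = computations_done_by(seq)  # answered from the cache, no new traversal
```

The first query traverses the sequence and its three stores; the second query for the same statement is served from the cache.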

To make this easier to follow, print a log message in ComputationsDoneByChildrenOf:

[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt : 
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)

[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)

[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i1] = ((x + y) + z)

[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i2] = ((x + y) + z)

[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i3] = (x + y)

[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 2)
((x + y), 1)
}

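Extension: from syntactic to semantic
This was mentioned earlier in the list of limitations and possible improvements. SyntacticToSemanticComputations, EquivalentTerms and related functions form a separate module.

The goal is to support:

  • Commutativity: x+y <=> y+x
  • Associativity: (x+y)+z <=> x+(y+z)
  • Distributivity: x*(y+z) <=> x*y+x*z
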
// Transform the hashtable of *syntactic* eligible computations into a vector of pairs
// containing *semantic* entities, i.e. where equivalent computations are merged.
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
    SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt);
...
/*!
 * \brief Decides if two terms are equivalent semantically
 */
bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
  // For now, we just check the syntactic equality, but that could later become a semantic test,
  // for instance identifying computations modulo commutativity (like x+y and y+x), or modulo
  // associativity (like (x+y)+z and x+(y+z)), etc.
  arith::Analyzer analyser;
  PrimExpr a_simplified = analyser.Simplify(a);
  PrimExpr b_simplified = analyser.Simplify(b);
  return EqualTerms(a_simplified, b_simplified);
}

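The above is a change I made to EquivalentTerms; tvm/arith already supports some semantic analysis. I also discussed it once with the developers. The conclusion was that, although it cannot cover every case, it is better than nothing.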
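
To illustrate what a semantic comparison could look like, here is a sketch that uses a different technique from the Simplify call above: it flattens nested additions and sorts the terms, which handles commutativity and associativity of + only. It works on toy tuples, ignores floating-point rounding, and its function names are hypothetical.

```python
def flatten_add(expr):
    """Sorted list of the leaf terms of a tree of additions."""
    if isinstance(expr, tuple) and expr[0] == "add":
        return sorted(flatten_add(expr[1]) + flatten_add(expr[2]), key=repr)
    return [expr]  # any other node is treated as an opaque term

def equivalent_terms(a, b):
    """Equality of additions modulo commutativity and associativity."""
    return flatten_add(a) == flatten_add(b)

commutes = equivalent_terms(("add", "x", "y"), ("add", "y", "x"))
associates = equivalent_terms(("add", ("add", "x", "y"), "z"),
                              ("add", "x", ("add", "y", "z")))
different = equivalent_terms(("add", "x", "y"), ("add", "x", "z"))
```

Distributivity is not covered by this normal form, which shows why a general simplifier is the more practical choice even if it is incomplete.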

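Creating new variables
Following the longest-first rule, the data structure is sorted in descending order of expression length (complexity). It is then traversed in that order, wrapping the result in a Let each time. A longer expression therefore ends up on the inside, enclosed by the Let blocks of the shorter expressions that are added later.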

  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
            });
  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
    ...
    Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
    ...
    result = Let(new_var, computation_and_nb.first, result);
  }

Replacing every occurrence with the new variable

// Replace in the current `result` everything that is selected by the selector with
// the new variable, without diving into expressions in which we don't have the
// right to dive.
result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(result, predicate_selector, new_var,
                                                        CanContainEligibleComputations);
...
class ReplaceSelectedExpr : public StmtExprMutator
...
PrimExpr ReplaceSelectedExpr::VisitExpr(const PrimExpr& expr) {
  // If the current expression is selected by the predicate
  if (predicate_selector_(expr)) {
    // Then simply return the new expression
    return new_expr_;
  } else {
    // If replacing inside the current expression is allowed
    if (can_replace_inside_(expr)) {
      // then we continue the exploration recursively
      return StmtExprMutator::VisitExpr(expr);
    } else {
      // otherwise we simply return the current expression
      return expr;
    }
  }
}

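There are three cases:

  • If the current expression is selected by predicate_selector_, new_expr_ is returned. This is where the replacement happens.
  • If the current expression is not selected but can_replace_inside_ allows descending into it, StmtExprMutator::VisitExpr(expr) recursively calls ReplaceSelectedExpr::VisitExpr on its subexpressions. Building the ComputationTable follows a similar call pattern.
  • Otherwise nothing is changed and expr itself is returned.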
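
The same three cases, transcribed to toy tuples. The predicate that refuses to descend into "call" nodes is a hypothetical choice made for this example; it is not a claim about what CanContainEligibleComputations does.

```python
def replace_selected(expr, selector, new_expr, can_replace_inside):
    """Replace what `selector` picks, descending only where `can_replace_inside` allows."""
    if selector(expr):
        return new_expr                                  # case 1: replace
    if isinstance(expr, tuple) and can_replace_inside(expr):
        return (expr[0],) + tuple(                       # case 2: recurse into children
            replace_selected(c, selector, new_expr, can_replace_inside)
            for c in expr[1:]
        )
    return expr                                          # case 3: leave untouched

xy = ("add", "x", "y")
expr = ("add", xy, ("call", "f", xy))
rewritten = replace_selected(
    expr,
    selector=lambda e: e == xy,
    new_expr="cse_var_1",
    can_replace_inside=lambda e: e[0] != "call",  # hypothetical: do not dive into calls
)
```

The first (x + y) is replaced, while the one inside the call is kept because the traversal is not allowed to enter that node.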

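Extension: pure functions
The earliest commit excluded function calls from the optimization, for the reason given above. A deeper optimization could, however, use the purity of a function to decide whether a call can be replaced. The original author, FrankQC, holds that two rules decide this: first, whether the function returns the same output for the same arguments; second, whether the function modifies external state. At the time of writing, a function's effect can be registered as one of: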

enum class CallEffectKind : int {
  /*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */
  kExprAnnotation = 0,
  /*!
   * \brief Pure function that do not interacts
   *        with any external state.
   */
  kPure = 1,
  /*!
   * \brief Function's that may read from states(e.g. RAM)
   */
  kReadState = 2,
  /*!
   * \brief Function that may read/write from states(e.g. RAM).
   */
  kUpdateState = 3,
  /*!
   * \brief Opaque function, cannot make any assumption
   */
  kOpaque = kUpdateState,
  /*!
   * \brief Special intrinsic to annotate call arguments info
   *        only valid as a direct argument to a call.
   */
  kSpecialCallArg = 4,
  /*!
   * \brief Embed opaque information in the Expr, cannot be codegen.
   */
  kEmbedInfo = 5,
  /*!
   * \brief Function that changes control flow
   */
  kControlJump = 6,
};

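A function registered as kPure=1 has the semantics "do not interacts with any external state". That says nothing about the first rule, so one cannot hastily assume that every function registered as kPure=1 in the system is truly pure. Introducing a new property, for example kDeterminant=1 to express the first rule, would require at least a consensus among all backend developers, because they have to annotate the effect kind of a function when registering an operator, for example in src/tir/op/builtin.cc: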

TIR_DEFINE_BUILTIN_FUNC(shift_left)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
    .set_attr<TVectorizable>("TVectorizable", true);

See the discussion here
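
How the two rules could combine is sketched below. The enum mirrors the numeric values listed above, but `may_eliminate_call` and its `is_deterministic` flag are hypothetical: no such attribute is described in this article as existing in TVM.

```python
from enum import IntEnum

class CallEffectKind(IntEnum):
    """Python mirror of the values listed above (kOpaque aliases kUpdateState)."""
    kExprAnnotation = 0
    kPure = 1
    kReadState = 2
    kUpdateState = 3
    kOpaque = 3
    kSpecialCallArg = 4
    kEmbedInfo = 5
    kControlJump = 6

def may_eliminate_call(effect_kind, is_deterministic):
    """Hypothetical rule: a call may be shared only if both conditions hold."""
    touches_no_external_state = effect_kind <= CallEffectKind.kPure
    return touches_no_external_state and is_deterministic

pure_and_deterministic = may_eliminate_call(CallEffectKind.kPure, True)
pure_but_random = may_eliminate_call(CallEffectKind.kPure, False)
reads_memory = may_eliminate_call(CallEffectKind.kReadState, True)
```

A call registered as kPure but not known to be deterministic is rejected here, which is the gap the paragraph above points out.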

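Postscript

CSE was contributed by Qualcomm. The code is commented in great detail and there is extensive discussion on the forum, so it is well worth studying closely.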
