Theory
Preliminaries
- TIR Let Binding
Let (var, value, body)
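Evaluates value, binds the result to var, then returns the result of evaluating body. In other words, Let binds an expression (Expr) to an immutable var in a local scope.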
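- scope: a scope, i.e. a code block. The TVM AST organizes scopes as a tree: an outer scope containing an inner scope is represented as a parent node and its child. Variables of a parent are visible to its children, but not vice versa. Determining whether a variable is in a given scope is an important part of the CSE algorithm.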
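To make the Let semantics and the scope rule concrete, here is a minimal sketch of an evaluator for a toy expression language. The node kinds and the helpers Make, Const, Var, Add, Let and Eval are hypothetical and are not TVM's tir::Let API. The sketch only shows three things: the value is evaluated in the enclosing scope, the bound variable is visible in the body, and it disappears once the body has been evaluated.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical toy IR for illustration only (not tvm::tir::Let).
struct Expr {
  enum class Kind { Const, Var, Add, Let } kind;
  int value = 0;               // Const: the literal value
  std::string name;            // Var: its name; Let: the bound variable
  std::shared_ptr<Expr> a, b;  // Add: operands; Let: a = value, b = body
};
using ExprPtr = std::shared_ptr<Expr>;
using Env = std::map<std::string, int>;  // variables visible in the current scope

ExprPtr Make(Expr::Kind k, int v, std::string n, ExprPtr a, ExprPtr b) {
  auto e = std::make_shared<Expr>();
  e->kind = k;
  e->value = v;
  e->name = std::move(n);
  e->a = std::move(a);
  e->b = std::move(b);
  return e;
}
ExprPtr Const(int v) { return Make(Expr::Kind::Const, v, "", nullptr, nullptr); }
ExprPtr Var(const std::string& n) { return Make(Expr::Kind::Var, 0, n, nullptr, nullptr); }
ExprPtr Add(ExprPtr x, ExprPtr y) { return Make(Expr::Kind::Add, 0, "", x, y); }
ExprPtr Let(const std::string& n, ExprPtr value, ExprPtr body) {
  return Make(Expr::Kind::Let, 0, n, value, body);
}

int Eval(const ExprPtr& e, const Env& env) {
  switch (e->kind) {
    case Expr::Kind::Const:
      return e->value;
    case Expr::Kind::Var: {
      auto it = env.find(e->name);
      if (it == env.end()) throw std::runtime_error("variable not in scope: " + e->name);
      return it->second;
    }
    case Expr::Kind::Add:
      return Eval(e->a, env) + Eval(e->b, env);
    case Expr::Kind::Let: {
      int bound = Eval(e->a, env);  // 1. evaluate `value` in the enclosing scope
      Env inner = env;              // 2. the child scope sees every parent binding...
      inner[e->name] = bound;       //    ...plus the new immutable binding
      return Eval(e->b, inner);     // 3. the let evaluates to the result of `body`
    }
  }
  throw std::logic_error("unknown expression kind");
}
```

For example, evaluating Let("x", Add(Const(1), Const(2)), Add(Var("x"), Var("x"))) in an empty environment gives 6. A use of x placed outside that Let throws, because x is not in scope there.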
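What it is
TIR code is generated from Relay and by other passes. Because this generation is automatic, the code contains a lot of redundancy. Common Subexpression Elimination (CSE) is a TIR pass that locates repeated computations and replaces them.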
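- Creates a new variable and substitutes it for every occurrence of the expression.
- Supports replacing whole expressions.
- Supports replacing subexpressions. For example:
  (w+x)+(y+z); (w+x)+u;  =>  new_var = (w+x); new_var+(y+z); new_var+u;
  Here (w+x) is the subexpression.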
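How it works
Prerequisite
TIR's SSA (Static Single Assignment) property: once bound, a variable's value never changes (it is immutable). Without this guarantee the replacement would be unsound. For example, suppose y = a+b is introduced and every a+b is replaced by y, but a is reassigned (say to 0) somewhere in between. A later a+b then really means 0+b, yet the replacement still reads the old value from y. Unless a+b is recomputed, the result is wrong.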
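A tiny sketch of the hazard described above, using ordinary mutable C++ variables rather than TIR. The names SecondUse and SecondOccurrenceOfAPlusB are hypothetical. Reusing the first a + b after a has been reassigned gives a different value from recomputing it. In TIR this situation cannot arise, because a variable bound by Let never changes.

```cpp
#include <cassert>

// Hypothetical illustration: what a naive CSE would produce if `a` could be reassigned.
struct SecondUse {
  int reused;      // value if the second a + b is replaced by the cached y
  int recomputed;  // value the program actually asks for
};

SecondUse SecondOccurrenceOfAPlusB(int a, int b, int new_a) {
  int y = a + b;           // first a + b, bound to y by the "CSE"
  a = new_a;               // `a` is overwritten between the two occurrences
  int reused = y;          // naive replacement of the second a + b
  int recomputed = a + b;  // what the second a + b really evaluates to
  return {reused, recomputed};
}
```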
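Locating repeated expressions
Selecting candidate expressions
- The expression is neither a constant nor a variable (if it is already a variable, there is no point in creating a new variable to replace it).
- The expression is not a function call or a memory load.
  - Functions are not necessarily pure. For a function with side effects, two calls with the same name and arguments may return different results.
  - Likewise, two identical memory loads may return different results (for example, if the memory is written between them), so a memory load cannot be a candidate.
- The expression also does not contain a function call or a memory load as a subexpression.
  - For example, replacing sum(f,f) is also unsafe, because f may have side effects.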
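The candidate rules can be sketched on a toy IR. The node kinds and the functions N, IsForbidden, Contains and IsEligible are hypothetical. They only mirror the structure of the IsEligibleComputation function shown in the implementation part: reject constants, variables, Ramp and Broadcast nodes, and anything that is, or contains, a call or a load.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR whose node kinds loosely mirror those tested by IsEligibleComputation.
enum class Kind { IntImm, FloatImm, StringImm, Var, Add, Mul, Call, Load, Ramp, Broadcast };

struct Node {
  Kind kind;
  std::string name;                         // variable, callee or buffer name
  std::vector<std::shared_ptr<Node>> args;  // operands / call arguments / load index
};
using NodePtr = std::shared_ptr<Node>;

NodePtr N(Kind k, std::string name = "", std::vector<NodePtr> args = {}) {
  return std::make_shared<Node>(Node{k, std::move(name), std::move(args)});
}

// Calls may have side effects and loads may observe writes, so both are forbidden.
bool IsForbidden(const NodePtr& e) { return e->kind == Kind::Call || e->kind == Kind::Load; }

// Does `e` itself, or any of its subexpressions, satisfy `pred`?
bool Contains(const NodePtr& e, bool (*pred)(const NodePtr&)) {
  if (pred(e)) return true;
  for (const auto& arg : e->args)
    if (Contains(arg, pred)) return true;
  return false;
}

bool IsEligible(const NodePtr& e) {
  switch (e->kind) {
    case Kind::IntImm:
    case Kind::FloatImm:
    case Kind::StringImm:  // constants
    case Kind::Var:        // already a variable
    case Kind::Ramp:
    case Kind::Broadcast:  // excluded because of TVM-internal constraints
      return false;
    default:
      // Neither a call/load itself nor containing one, e.g. x + f(y) or x + Mem[i].
      return !Contains(e, IsForbidden);
  }
}
```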
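Deciding whether to act on a candidate expression
All variables used by the expression must be in the current scope, and the expression must occur more than once.
For an expression that does not meet these conditions, its subexpressions are considered recursively.
For example, if (w+x)*(y+z) does not qualify, the subexpressions it contains, (w+x) and (y+z), are considered next.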
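The following is a minimal sketch of this decision, under simplifying assumptions. Expressions are keyed by their printed form, and the occurrence counts are given. When a candidate is rejected, each of its direct non-variable subexpressions inherits its count: the count is merged into an existing entry, or the subexpression is appended at the end, with no particular ordering maintained. The names SelectForIntroduction and AllVarsInScope are hypothetical. Take the counts ((x + y) + z) x1 and (x + y) x1, which appear in the second ComputationTable of the test_cse_cascade log in the implementation part. The outer computation is rejected, (x + y) reaches a count of 2, and so (x + y) is selected.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression: a variable leaf (empty op) or a binary "+" / "*" node.
struct Expr {
  std::string op;
  std::string name;
  std::vector<std::shared_ptr<Expr>> kids;
};
using ExprPtr = std::shared_ptr<Expr>;
using Candidate = std::pair<ExprPtr, std::size_t>;  // computation and its occurrence count

ExprPtr V(const std::string& n) { return std::make_shared<Expr>(Expr{"", n, {}}); }
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", {a, b}});
}
std::string Str(const ExprPtr& e) {
  if (e->op.empty()) return e->name;
  return "(" + Str(e->kids[0]) + " " + e->op + " " + Str(e->kids[1]) + ")";
}

bool AllVarsInScope(const ExprPtr& e, const std::vector<std::string>& scope) {
  if (e->op.empty()) return std::find(scope.begin(), scope.end(), e->name) != scope.end();
  for (const auto& k : e->kids)
    if (!AllVarsInScope(k, scope)) return false;
  return true;
}

// Selects candidates seen at least twice whose variables are all in scope; a rejected
// candidate passes its count on to its direct non-leaf subexpressions.
std::vector<std::string> SelectForIntroduction(std::vector<Candidate> work,
                                               const std::vector<std::string>& scope) {
  std::vector<std::string> selected;
  for (std::size_t i = 0; i < work.size(); ++i) {
    ExprPtr expr = work[i].first;  // copies, since `work` may grow (and reallocate) below
    std::size_t count = work[i].second;
    if (count >= 2 && AllVarsInScope(expr, scope)) {
      selected.push_back(Str(expr));
      continue;
    }
    for (const auto& kid : expr->kids) {
      if (kid->op.empty()) continue;  // variables are never candidates
      auto same = std::find_if(work.begin(), work.end(),
                               [&](const Candidate& c) { return Str(c.first) == Str(kid); });
      if (same != work.end()) {
        same->second += count;  // merge into the existing entry
      } else {
        work.push_back({kid, count});
      }
    }
  }
  return selected;
}
```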
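Data structures
- Context: a vector<pair<Var, MaybeValue>> recording which variables are in the current scope (and, when available, the value bound to each).
- Table of computations: an unordered_map that counts expressions.
  - key: the expression (a PrimExpr). For example, in the Stmt buffer[i1] = ((x + y) + z), both ((x + y) + z) and (x + y) are PrimExprs; the latter is reached by visiting the subexpressions of the former.
  - value: the number of times the expression occurs.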
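The following is a minimal sketch of how such a table can be filled. It assumes a toy expression type and uses the printed form of an expression as the key, whereas TVM keys the map by the PrimExpr itself. All names (Expr, V, Bin, Call, Count, ComputationTableOf) are hypothetical. The sketch follows what the test_cse_cascade log in the implementation part shows: an eligible computation is counted as a whole, and its inside is not counted again. So (x + y) inside ((x + y) + z) does not add to the count of (x + y). An ineligible node is searched for smaller eligible computations.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy expression: a variable leaf (empty op), "+" / "*", or a one-argument call.
struct Expr {
  std::string op;    // "", "+", "*" or "call"
  std::string name;  // variable name or callee
  std::vector<std::shared_ptr<Expr>> kids;
};
using ExprPtr = std::shared_ptr<Expr>;
using Table = std::map<std::string, std::size_t>;  // printed computation -> occurrences

ExprPtr V(const std::string& n) { return std::make_shared<Expr>(Expr{"", n, {}}); }
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", {a, b}});
}
ExprPtr Call(const std::string& f, ExprPtr arg) {
  return std::make_shared<Expr>(Expr{"call", f, {arg}});
}

std::string Str(const ExprPtr& e) {
  if (e->op.empty()) return e->name;
  if (e->op == "call") return e->name + "(" + Str(e->kids[0]) + ")";
  return "(" + Str(e->kids[0]) + " " + e->op + " " + Str(e->kids[1]) + ")";
}

bool ContainsCall(const ExprPtr& e) {
  if (e->op == "call") return true;
  for (const auto& k : e->kids)
    if (ContainsCall(k)) return true;
  return false;
}
bool IsEligible(const ExprPtr& e) { return !e->op.empty() && !ContainsCall(e); }

// An eligible computation is counted as a whole and not explored further;
// anything else is searched for smaller eligible computations.
void Count(const ExprPtr& e, Table& table) {
  if (IsEligible(e)) {
    ++table[Str(e)];
    return;
  }
  for (const auto& k : e->kids) Count(k, table);
}

Table ComputationTableOf(const std::vector<ExprPtr>& stored_values) {
  Table table;
  for (const auto& v : stored_values) Count(v, table);
  return table;
}
```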
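Creating new variables
Because the expressions that receive new variables may contain one another, the variable for a longer expression is placed inside the let scope and the variable for a shorter expression outside it, so that the longer expression's value can refer to the shorter one's variable.
As shown in the figure, the outer smallComp could be y=a+b and z=d+e, and the inner bigComp could then be p=y+z.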
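The nesting rule can be sketched as follows, assuming the computations to introduce have already been chosen. The names Program, IntroduceVars, Complexity and Replace are hypothetical, and complexity is simply the node count here. Each computation, longest first, gets a fresh variable. Every occurrence of the computation is replaced, including inside the values of lets created earlier. The new let is placed outside the existing ones. On the stores of test_cse_cascade, this toy reproduces the nesting in the PrintIR() output of the implementation part: cse_var_2 = (x + y) outside, cse_var_1 = (cse_var_2 + z) inside.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression: a variable leaf (empty op) or a binary "+" / "*" node.
struct Expr {
  std::string op;
  std::string name;
  std::vector<std::shared_ptr<Expr>> kids;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr V(const std::string& n) { return std::make_shared<Expr>(Expr{"", n, {}}); }
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", {a, b}});
}
std::string Str(const ExprPtr& e) {
  if (e->op.empty()) return e->name;
  return "(" + Str(e->kids[0]) + " " + e->op + " " + Str(e->kids[1]) + ")";
}
int Complexity(const ExprPtr& e) {  // number of nodes
  int n = 1;
  for (const auto& k : e->kids) n += Complexity(k);
  return n;
}
// Returns a copy of `e` in which every subexpression printed as `key` becomes variable `var`.
ExprPtr Replace(const ExprPtr& e, const std::string& key, const std::string& var) {
  if (!e->op.empty() && Str(e) == key) return V(var);
  auto copy = std::make_shared<Expr>(*e);
  for (auto& k : copy->kids) k = Replace(k, key, var);
  return copy;
}

struct Program {
  std::vector<std::pair<std::string, ExprPtr>> lets;  // outermost binding first
  std::vector<ExprPtr> stores;                        // values stored by the body
};

Program IntroduceVars(Program prog, std::vector<ExprPtr> comps) {
  std::stable_sort(comps.begin(), comps.end(), [](const ExprPtr& a, const ExprPtr& b) {
    return Complexity(a) > Complexity(b);  // longest computation first
  });
  int counter = 0;
  for (const auto& comp : comps) {
    std::string var = "cse_var_" + std::to_string(++counter);
    std::string key = Str(comp);
    for (auto& binding : prog.lets) binding.second = Replace(binding.second, key, var);
    for (auto& s : prog.stores) s = Replace(s, key, var);
    // The shorter computation is bound outside every let created so far.
    prog.lets.insert(prog.lets.begin(), std::make_pair(var, comp));
  }
  return prog;
}
```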
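Limitations and possible improvements
- Richer semantic equivalences could be supported in the future, for example (x+y)+z <=> z+(x+y).
- Functions with side effects could be distinguished from pure ones, enabling deeper optimizations.
References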
TVM Conference 2021 Qualcomm
TVM 拆包(一):Runtime basics
Implementation
First, look at tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse, which prints the TIR before and after the pass:
@main = primfn(i1: int32, i2: int32, z3: int32) -> () {
let z1: int32 = 1
let z2: int32 = 2
{
buffer: Pointer(int32)[i1] = (z1 + z2)
let x: int32 = 1
let y: int32 = 1
let a: int32 = ((x + y) + (z1 + z2))
let b: int32 = ((x + y) + z3)
buffer[i2] = (a + b)
}
}
[15:01:55] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, z3: int32) -> () {
let z1: int32 = 1
let z2: int32 = 2
let cse_var_1: int32 = (z1 + z2)
{
buffer: Pointer(int32)[i1] = cse_var_1
let x: int32 = 1
let y: int32 = 1
let cse_var_2: int32 = (x + y)
let a: int32 = (cse_var_2 + cse_var_1)
let b: int32 = (cse_var_2 + z3)
buffer[i2] = (a + b)
}
}
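We can see that CSE created two variables, cse_var_1 = z1+z2 and cse_var_2 = x+y, and substituted them for the corresponding expressions. This is valid because z1+z2 and x+y each occur twice in the same scope and involve no load operations. Note also that cse_var_2 is bound only after x and y are defined, since those variables are not in scope earlier.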
test_tir_transform_common_subexpr_elim.py::test_cse_cascade:
yuan@yuan:~/Coding/compiler/repos/tvm$ python -m pytest /home/yuan/Coding/compiler/repos/tvm/tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse_cascade -s
enabled targets: llvm; cuda; nvptx
pytest marker:
====================================================================== test session starts ======================================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/yuan/Coding/compiler/repos/tvm
collected 1 item
tests/python/unittest/test_tir_transform_common_subexpr_elim.py @main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
buffer: Pointer(int32)[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
}
[15:17:37] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
let cse_var_2: int32 = (x + y)
let cse_var_1: int32 = (cse_var_2 + z)
{
buffer: Pointer(int32)[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
}
}
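The replacement follows the rule described earlier: longer expressions inside, shorter ones outside.
With the examples covered, let us trace how the source code handles the second one. Logging was added at the recursion entry, at the function's input and output, and after building the computation hash table, to observe the call sequence: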
yuan@yuan:~/Coding/compiler/repos/tvm$ python -m pytest /home/yuan/Coding/compiler/repos/tvm/tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse_cascade -s
enabled targets: llvm; cuda; nvptx
pytest marker:
============================================================= test session starts ==============================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/yuan/Coding/compiler/repos/tvm
collected 1 item
tests/python/unittest/test_tir_transform_common_subexpr_elim.py @main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
buffer: Pointer(int32)[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/runtime/logging.cc:239: TVM_LOG_DEBUG enables VLOG statements in 'tir/transforms/common_subexpr_elim.cc' up to level 1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 2)
((x + y), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:493: variables_created true
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
let cse_var_1 = ((x + y) + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = (x + y)
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 1)
((x + y), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:493: variables_created true
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
((x + y), 1)
((cse_var_2 + z), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
((cse_var_2 + z), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i1] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i2] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i2] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
let cse_var_2: int32 = (x + y)
let cse_var_1: int32 = (cse_var_2 + z)
{
buffer: Pointer(int32)[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
}
}
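Based on the log output, I simulated the execution flow (see figure); the box contains the original computation graph. cse_var_1 and then cse_var_2 are added from the bottom up, whereas child nodes are traversed top-down with DFS.
Context update / call entry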
PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
// At this point, we have already done the generic treatment of introducing (via let-in) what
// was doable at the toplevel of the given let-in.
// Save the context at the entry of the function
Context context_at_entry = context_;
// Recurse on the `value` field for potentially rewriting it
PrimExpr value_new = VisitExpr(op->value);
// Augment the context with the association (`var`, `value`) for preparing the next recursion
// on the `body`
context_.push_back({op->var, MaybeValue(op->value)});
// Recurse on the `body` (with this extended context)
// The recursive call will have potentially done new simplifications, because in this recursive
// call `var` will be a part of the context.
// (see in VisitExpr() that no introduction were performed when a computation was using an
// undefined variable, as that would lead to ill-formed code)
PrimExpr body_new = VisitExpr(op->body);
// Restaure the context to its content at the entrance to not carry out of scope declarations
// as the variable introduced by the let-in is not in scope outside of its body
context_ = context_at_entry;
// Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
// have been done.
// If the `value` and the `body` of the let-in have been rewritten to the same thing
if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
// then return a reference to the same node
return GetRef<PrimExpr>(op);
} else {
// Otherwise return a let-in built with the new `value_new` and the new `body_new` that
// have just been obtained
return Let(op->var, value_new, body_new, op->span);
}
}
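CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) is the entry point from which VisitExpr is called for a let-in. It:
- saves the context, then recursively visits the value;
- adds the variable bound by the let (together with its value) to the context, in preparation for visiting the body;
- recursively visits the body; the let's variable is now in the context and is therefore visible there;
- restores the context, so the variable does not leak out of its scope.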
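The save / extend / restore pattern can be isolated in a small hypothetical visitor. ScopeChecker, V, Add and Let are invented for this sketch, and the context is just a vector of names. The visitor records every variable use that the current context does not cover. This shows that a let's variable is visible in its body, but not in its own value, and not after the let.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR with variables, additions and let-in expressions.
struct Expr {
  enum class Kind { Var, Add, Let } kind;
  std::string name;            // Var: its name; Let: the bound variable
  std::shared_ptr<Expr> a, b;  // Add: operands; Let: a = value, b = body
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr V(const std::string& n) {
  return std::make_shared<Expr>(Expr{Expr::Kind::Var, n, nullptr, nullptr});
}
ExprPtr Add(ExprPtr x, ExprPtr y) {
  return std::make_shared<Expr>(Expr{Expr::Kind::Add, "", x, y});
}
ExprPtr Let(const std::string& n, ExprPtr value, ExprPtr body) {
  return std::make_shared<Expr>(Expr{Expr::Kind::Let, n, value, body});
}

// Mirrors the save / extend / restore structure of VisitExpr_(const LetNode*).
class ScopeChecker {
 public:
  explicit ScopeChecker(std::vector<std::string> params) : context_(std::move(params)) {}

  void Visit(const ExprPtr& e) {
    switch (e->kind) {
      case Expr::Kind::Var:
        if (std::find(context_.begin(), context_.end(), e->name) == context_.end())
          undefined_uses.push_back(e->name);
        break;
      case Expr::Kind::Add:
        Visit(e->a);
        Visit(e->b);
        break;
      case Expr::Kind::Let: {
        std::vector<std::string> context_at_entry = context_;  // save the context
        Visit(e->a);                  // value: the let's own variable is not visible yet
        context_.push_back(e->name);  // extend the context for the body
        Visit(e->b);
        context_ = context_at_entry;  // restore: the variable goes out of scope
        break;
      }
    }
  }

  std::vector<std::string> undefined_uses;  // uses found outside any binding scope

 private:
  std::vector<std::string> context_;
};
```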
Candidate expression rules
bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
return (
// In order to be eligible, the given expression should not be a constant
(expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
(expr.as<StringImmNode>() == nullptr)
// and it should not be a variable
&& (expr.as<VarNode>() == nullptr)
// and it should not be a forbidden computation (function calls and loads)
&& (!ForbiddenComputation(expr))
// and it should not even contain a forbidden computation (function calls and loads)
// the reason is that we don't want to register expressions like (x + f(y)) or
// (x + Mem[i]) as introducing them into variables could change the semantics
&& (!CheckContains::ExprContains(expr, ForbiddenComputation))
// and it should not be a ramp node or a broadcast node due to some internals TVM
// constraints (which check for these node explicitely without performing any
// evaluation first, so if they have been put into variables it fails)
&& (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
}
About ComputationTable
// Obtain the (syntactic) eligible computations done by the input statement, and keep it as
// a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the
// number of time this exact syntactic computation is being computed.
ComputationTable table_syntactic_comp_done_by_stmt = ComputationsDoneBy::GetComputationsDoneBy(
stmt, IsEligibleComputation, CanContainEligibleComputations);
...
std::unordered_map<Stmt, ComputationTable, ObjectPtrHash, ObjectPtrEqual>
cache_stmt_table_computations_;
...
void ComputationsDoneBy::VisitStmt(const Stmt& stmt) {
// See if we have already computed the (table of) computations done by `stmt`
auto it_table_stmt = cache_.cache_stmt_table_computations_.find(stmt);
if (it_table_stmt != cache_.cache_stmt_table_computations_.end()) {
// We need to do the union with `table_of_computations_` instead of just writing into it,
// because some other childs might have added things into it too. The reason for that is
// that `table_of_computations_` is shared between the child nodes of a given statement.
UnionOfComputationTables(&table_of_computations_, it_table_stmt->second);
return;
}
// If we reach this point, it means that we have never computed before the computations done
// by `stmt` and will do so now.
// The computations done by a Stmt node are just the ones done by its children
ComputationTable temp =
ComputationsDoneByChildrenOf(stmt, is_eligible_computation_, can_contain_computations_);
// We need to do the union with `table_of_computations_` instead of just writing into it,
// because some other childs might have added things into it too. The reason for that is
// that `table_of_computations_` is shared between the child nodes of a given expression.
UnionOfComputationTables(&table_of_computations_, temp);
}
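The program keeps a cache, an unordered_map<Stmt, ComputationTable>, that holds the ComputationTable of each Stmt already processed. When a Stmt is visited again, its ComputationTable is looked up in this cache instead of being recomputed.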
ComputationTable ComputationsDoneBy::ComputationsDoneByChildrenOf(
const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
std::function<bool(const PrimExpr&)> can_contain_computations) {
// We will be using an instance of the class ComputationsDoneBy for the child nodes
// (ie, they will share the "result" that `table_of_computations_` is)
ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations);
// Calls the *dispatcher* (not the overriden method)
computations_done_by.StmtExprVisitor::VisitStmt(stmt);
// So now we can copy table_of_computations_ into the cache for the future queries
// Note : in the table, the computations done by `stmt` is set to the computations done by its
// children, because that's exactly what we mean by "the computations of a statement".
cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_;
return computations_done_by.table_of_computations_;
}
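ComputationsDoneByChildrenOf and ComputationsDoneBy::VisitStmt actually call each other recursively, because of this line in ComputationsDoneByChildrenOf:
// Calls the *dispatcher* (not the overriden method)
computations_done_by.StmtExprVisitor::VisitStmt(stmt);
StmtExprVisitor::VisitStmt dispatches on the node type and recursively visits the statement's children and expressions; visiting a child statement goes back through the overridden ComputationsDoneBy::VisitStmt.
The execution order can thus be simulated as:
ComputationsDoneBy::VisitStmt (given a Stmt) -> ComputationsDoneByChildrenOf -> computations_done_by.StmtExprVisitor::VisitStmt(stmt) (reaching the next Stmt) -> ComputationsDoneBy::VisitStmt -> …
To make this easier to follow, a log statement was added in ComputationsDoneByChildrenOf: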
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i1] = ((x + y) + z)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i2] = ((x + y) + z)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i3] = (x + y)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 2)
((x + y), 1)
}
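The mutual recursion and the per-statement cache can be mimicked in a few lines. This is a hypothetical sketch. The toy Stmt is either a sequence of child statements or a store of one computation that has already been printed. A std::map keyed by shared_ptr, which compares by identity, stands in for the unordered_map<Stmt, ComputationTable>. children_walks is an invented counter whose only purpose is to show that a second query is answered from the cache.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy statement: a sequence (non-empty children) or a store of one computation.
struct Stmt {
  std::vector<std::shared_ptr<Stmt>> children;
  std::string computation;  // printed eligible computation stored by a leaf
};
using StmtPtr = std::shared_ptr<Stmt>;
using Table = std::map<std::string, std::size_t>;

void UnionInto(Table& dst, const Table& src) {
  for (const auto& kv : src) dst[kv.first] += kv.second;
}

class ComputationsDoneBy {
 public:
  static Table GetComputationsDoneBy(const StmtPtr& s) {
    ComputationsDoneBy visitor;
    visitor.VisitStmt(s);
    return visitor.table_;
  }
  static int children_walks;  // how many times a statement's children were actually walked

 private:
  static std::map<StmtPtr, Table>& Cache() {
    static std::map<StmtPtr, Table> cache;  // keyed by statement identity
    return cache;
  }

  void VisitStmt(const StmtPtr& s) {
    auto it = Cache().find(s);
    if (it != Cache().end()) {  // already computed: merge the cached table
      UnionInto(table_, it->second);
      return;
    }
    UnionInto(table_, ComputationsDoneByChildrenOf(s));
  }

  static Table ComputationsDoneByChildrenOf(const StmtPtr& s) {
    ++children_walks;
    ComputationsDoneBy visitor;  // one table shared by all children of `s`
    if (s->children.empty()) ++visitor.table_[s->computation];
    for (const auto& c : s->children) visitor.VisitStmt(c);  // recurses back into VisitStmt
    Cache()[s] = visitor.table_;
    return visitor.table_;
  }

  Table table_;
};
int ComputationsDoneBy::children_walks = 0;
```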
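Extension: from syntax to semantics
As mentioned under "Limitations and possible improvements", SyntacticToSemanticComputations, EquivalentTerms and related functions form a separate module.
The goal is to support:
- commutativity: x+y <=> y+x
- associativity: (x+y)+z <=> x+(y+z)
- distributivity: x*(y+z) <=> x*y+x*z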
// Transform the hashtable of *syntactic* eligible computations into a vector of pairs
// containing *semantic* entities, i.e. where equivalent computations are merged.
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt);
...
/*!
* \brief Decides if two terms are equivalent semantically
*/
bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
// For now, we just check the syntactic equality, but that could later become a semantic test,
// for instance identifying computations modulo commutativity (like x+y and y+x), or modulo
// associativity (like (x+y)+z and x+(y+z)), etc.
arith::Analyzer analyser;
PrimExpr a_simplified = analyser.Simplify(a);
PrimExpr b_simplified = analyser.Simplify(b);
return EqualTerms(a_simplified, b_simplified);
}
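The above is my change to EquivalentTerms. Since tvm/arith already supports part of this semantic analysis, both terms are simplified with arith::Analyzer before being compared. I also discussed this once with the developers; the conclusion was that although it cannot cover every case, it is better than nothing.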
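A toy canonical form illustrates the difference between syntactic and semantic equality. It assumes only "+" and "*" over variables. This is a swapped-in technique: it flattens same-operator chains and then sorts the operands, which is not what arith::Analyzer::Simplify does. The names Canon, Flatten and EquivalentModuloAC are hypothetical. The canonical form identifies terms modulo commutativity and associativity, but it does not handle distributivity.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy expression: a variable leaf (empty op) or a binary "+" / "*" node.
struct Expr {
  std::string op;
  std::string name;
  std::vector<std::shared_ptr<Expr>> kids;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr V(const std::string& n) { return std::make_shared<Expr>(Expr{"", n, {}}); }
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", {a, b}});
}

std::string Canon(const ExprPtr& e);

// Collects the operands of a chain of the same operator: ((x + y) + z) -> x, y, z.
void Flatten(const ExprPtr& e, const std::string& op, std::vector<std::string>& out) {
  if (e->op == op) {
    for (const auto& k : e->kids) Flatten(k, op, out);
  } else {
    out.push_back(Canon(e));
  }
}

// Canonical form modulo associativity and commutativity of "+" and "*".
std::string Canon(const ExprPtr& e) {
  if (e->op.empty()) return e->name;
  std::vector<std::string> operands;
  Flatten(e, e->op, operands);
  std::sort(operands.begin(), operands.end());  // operand order no longer matters
  std::string s = "(" + e->op;
  for (const auto& o : operands) s += " " + o;
  return s + ")";
}

bool EquivalentModuloAC(const ExprPtr& a, const ExprPtr& b) { return Canon(a) == Canon(b); }
```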
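Creating new variables
Following the longest-first rule, the computations are sorted by complexity in descending order and then traversed in that order, each one wrapping the current result in a Let. A longer expression therefore ends up inside, enclosed by the Let blocks added later for shorter expressions.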
std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
});
for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
...
Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
...
result = Let(new_var, computation_and_nb.first, result);
}
Replacing all occurrences with the new variable
// Replace in the current `result` everything that is selected by the selector with
// the new variable, without diving into expressions in which we don't have the
// right to dive.
result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(result, predicate_selector, new_var,
CanContainEligibleComputations);
...
class ReplaceSelectedExpr : public StmtExprMutator
...
PrimExpr ReplaceSelectedExpr::VisitExpr(const PrimExpr& expr) {
// If the current expression is selected by the predicate
if (predicate_selector_(expr)) {
// Then simply return the new expression
return new_expr_;
} else {
// If replacing inside the current expression is allowed
if (can_replace_inside_(expr)) {
// then we continue the exploration recursively
return StmtExprMutator::VisitExpr(expr);
} else {
// otherwise we simply return the current expression
return expr;
}
}
}
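Case by case:
- If the current expression is selected by predicate_selector_, new_expr_ is returned. This is where the substitution happens.
- If it is not selected but can_replace_inside_ allows diving into it,
StmtExprMutator::VisitExpr(expr);
recursively calls ReplaceSelectedExpr::VisitExpr on its subexpressions. ComputationTable is built with a similar calling pattern.
- Otherwise, nothing is changed and expr itself is returned.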
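The three cases can be reproduced on a toy expression type. The names ReplaceSelected and Pred are hypothetical and are not TVM's API. In this sketch the "may dive into" predicate is an arbitrary choice: it refuses to look inside multiplications. The point is to show that an occurrence behind such a barrier is left untouched, while every other occurrence is replaced.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy expression: a variable leaf (empty op) or a binary "+" / "*" node.
struct Expr {
  std::string op;
  std::string name;
  std::vector<std::shared_ptr<Expr>> kids;
};
using ExprPtr = std::shared_ptr<Expr>;
using Pred = std::function<bool(const ExprPtr&)>;

ExprPtr V(const std::string& n) { return std::make_shared<Expr>(Expr{"", n, {}}); }
ExprPtr Bin(const std::string& op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", {a, b}});
}
std::string Str(const ExprPtr& e) {
  if (e->op.empty()) return e->name;
  return "(" + Str(e->kids[0]) + " " + e->op + " " + Str(e->kids[1]) + ")";
}

// Toy counterpart of ReplaceSelectedExpr::VisitExpr.
ExprPtr ReplaceSelected(const ExprPtr& e, const Pred& selected, const ExprPtr& new_expr,
                        const Pred& can_replace_inside) {
  if (selected(e)) return new_expr;                         // case 1: substitute
  if (!can_replace_inside(e) || e->kids.empty()) return e;  // case 3: keep as is
  auto copy = std::make_shared<Expr>(*e);                   // case 2: rewrite the children
  for (auto& k : copy->kids) k = ReplaceSelected(k, selected, new_expr, can_replace_inside);
  return copy;
}
```

For example, with e = (x + y) + ((x + y) * z), a selector matching (x + y) and a barrier on "*", the result is (cse_var_1 + ((x + y) * z)).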
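Extension: functions with the pure property
The earliest commit forbade optimizing function calls altogether, for the reasons given earlier. A deeper optimization, however, could use whether a function is pure to decide whether a call may be replaced. The original author, FrankQC, considers two criteria: first, whether the function always returns the same output for the same arguments; second, whether it modifies external state. At the time of writing, the properties of a function can be registered as: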
enum class CallEffectKind : int {
/*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */
kExprAnnotation = 0,
/*!
* \brief Pure function that do not interacts
* with any external state.
*/
kPure = 1,
/*!
* \brief Function's that may read from states(e.g. RAM)
*/
kReadState = 2,
/*!
* \brief Function that may read/write from states(e.g. RAM).
*/
kUpdateState = 3,
/*!
* \brief Opaque function, cannot make any assumption
*/
kOpaque = kUpdateState,
/*!
* \brief Special intrinsic to annotate call arguments info
* only valid as a direct argument to a call.
*/
kSpecialCallArg = 4,
/*!
* \brief Embed opaque information in the Expr, cannot be codegen.
*/
kEmbedInfo = 5,
/*!
* \brief Function that changes control flow
*/
kControlJump = 6,
};
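A function registered as kPure=1 has the documented semantics "do not interacts with any external state". One therefore should not hastily assume that every function registered as kPure in the system is pure. Introducing a new property, for example kDeterminant=1 for the first criterion, would require at least a consensus among all backend developers, because backend developers annotate these properties when registering operators, e.g. in src/tir/op/builtin.cc: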
TIR_DEFINE_BUILTIN_FUNC(shift_left)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TVectorizable>("TVectorizable", true);
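See the discussion here.
Afterword
CSE was contributed by Qualcomm. The code is thoroughly commented and the forum contains a lot of related discussion, so it is well worth studying closely.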