之前的文章中讲到了,MLIR可以通过两种方法将产生的表达式进行变型,即直接使用C++编写表达式的匹配与重写函数和使用DRR规则来定义重写规则。通过上述两种方法,可以将产生的表达式进行自定义的变型,比如去掉冗余
transpose
操作等等。但是,这样的方法只能针对特定的一个Operation进行变型,在遇到相似的变型时,无法重用之前的变型方法。因此本文将会介绍泛化的表达式变型。
回顾上一篇文章的表达式优化例程,我们处理了冗余的transpose
操作和reshape
操作,针对这两个操作,我们分别编写了各自的匹配和重写方法。但是对于像这种相似的变型,很显然是有一些方法可以进行重用的。因此,MLIR使用接口来实现泛化的表达式变型。
本文采用下面的例子来进行具体介绍:
def multiply_transpose(a, b) {
return transpose(a) * transpose(b);
}
def main() {
var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
var b<2, 3> = [1, 2, 3, 4, 5, 6];
var c = multiply_transpose(a, b);
var d = multiply_transpose(b, a);
print(d);
}
例子中使用Toy语言定义了multiply_transpose
方法,也就是将两个张量进行转置,再进行对应元素相乘(注意:这里的*
运算符表示对应元素相乘)。
对于这样的例子,当我们不进行任何优化生成MLIR表达式的时候,除了在实例化tensor的时候,其他时候并不能知道tensor的shape信息。如下图所示,使用的是泛化的tensor类型。
这样一来,就会使后续的代码优化和代码生成变得更加复杂。我们在上一篇文章的优化之前增加两个pass,首先进行內联,然后推断运算数据的shape,最后再消除冗余进行规范化。使用这样的方法我们就可以使每一个tensor都有一个确定的shape信息。新添加的这两个pass分别运用了Dialect接口和Operation接口。
內联(inline)pass
使用內联pass,可以用定义的函数体替代函数的调用。对于multiply_transpose
这种函数,函数的调用、返回有关的准备和收尾工作的代码往往比函数体本身的代码要大得多。因此,对于这类简单的、使用频繁的小函数,将之声明为内联函数可提高运行效率。那么在上述示例中,应该如何实现內联的pass呢?
首先,我们要实现內联pass的表达式变型规则,这里就要提到了Dialect接口,MLIR提供了DialectInlinerInterface
接口,也就是说MLIR本身提供了处理內联的接口。这样一来,我们只需要继承DialectInlinerInterface
,然后再实现所需函数,就可以自定义內联的操作,制定表达式变型规则。
除此之外,內联的过程还需要知道在哪里进行內联,这就要求知道在main函数中调用multiply_transpose
的位置。这里我们将要提到的是Operation接口,MLIR中提供了一个CallOpInterface
接口, 将会对函数调用进行标记,以此告知內联操作函数调用的位置。Operation接口的粒度比Dialect接口的小,也就是说Dialect接口的作用范围是整个Dialect,Operation接口就是针对特定Operation进行操作。
在确定了內联操作的表达式变型规则以及定位到函数调用位置之后,下面解决的问题就是,函数定义时的类型和函数调用时的类型不一致,也就是说当函数定义时的参数使用的是泛化的tensor类型,但是函数调用时的参数是确定类型的tensor,在內联操作时需要对其类型进行统一。为了解决这样的问题,我们需要增加一个CastOp,用来将变量确定的类型转变为泛化的类型。在Operation模块中定义CastOp,再在Dialect模块的內联操作接口中,重写相应函数,从而实现在MLIR表达式中加入CastOp,将确定的类型转化为泛化的类型。这样一来,在函数调用处,內联操作就可以将函数定义,嵌入到函数调用位置。
到此为止,我们实现了內联操作,现在的MLIR表达式已经没有了函数定义的表达式,全部嵌入了main函数中。
Shape推断 Pass
我们通过CastOp实现了将确定类型的tensor转变成了泛化类型的tensor ,才使內联操作得以完成。下一步要解决的就是,根据那些确定类型tensor,将那些泛化的tensor都转变为确定类型的tensor,这里我们将要用ODS框架来生成自定义的Operation接口,这个接口就是用来推测tensor的shape类型。整个Shape推断也将会编写为一个pass作用在MLIR表达式上。
首先,使用ODS框架的规则编写Shape推断的接口模块。由于这是Operation接口,所以是作用在特定的Operation上的,因此下一步就是要指定哪些Operation将会使用Shape推断接口。除此之外,还要为各个使用Shape推断接口的Operation,实现接口中的inferShapes
函数。最后定义一个Shape推断接口的类,提供Shape推断算法,并创建一个Pass,在PassManger中进行添加。
其中,Shape推断的算法如下图所示:
到此为止,我们实现了內联操作,并且成功推测了各个泛化tensor的Shape信息。我们使用了Dialect接口使用了泛化的內联操作接口,并进行自定义操作;使用了MLIR提供的Operation接口定位函数调用位置;使用了ODS框架定义Operation接口,进行Shape推断。再配合其他的几个规范化的Pass,得到了最终变型后的MLIR表达式的结果。
借助这个例子也说明了使用泛化的重用表达式变型的操作,同时也说明了MLIR可重用性与可扩展性。
本文参考自MLIR官方文档,如有错误纰漏,欢迎大家批评指正。