【TVM帮助文档学习】Relay的模式匹配

本文翻译自Pattern Matching in Relay — tvm 0.9.dev0 documentation

在TVM中,我们在很多地方会识别Relay程序的纯数据流子图,并尝试以某种方式对它们进行转换,包括融合、量化、外部代码生成等passes, 以及针对特定设备的优化,比如VTA使用的bitpacking和和layer slicing等。

如今,许多这样的pass需要大量枯燥的样板代码来实现,同时还需要用户从访问者和AST匹配的角度来考虑。这里面很多转换可以很容易地用图重写来描述。为了构建一个重写器或其他高级机制,我们首先需要一种模式语言来描述我们可以匹配的内容。

这种语言不仅对构建重写器有用,而且还为现有的pass提供扩展点。例如,融合pass可以通过一组描述硬件性能的融合模式来参数化,而量化pass可以采用一组模式来描述给定平台上哪些操作符可以量化。

在后端,我们可以使用相同的机制,使用自带的代码生成来构建更高级别的API。这个API采用一组描述硬件功能的模式和一个外部编译器,提供了一种开箱即用的相对流畅的异构体验。

模式实例

有相当多的算子属性能进行匹配。下面我们将研究如何匹配树属性,并扩展原型中没有充分探讨的一些用例。本节演示如何编写模式。建议查阅tests/python/relay/test_dataflow_pattern.py以获得更多用例。

匹配两个算子中的一个

第一个例子展示了一个简单的场景:我们想要匹配两个单输入算子中的一个:

def test_match_op_or():
    is_add_or_sub = is_op('add') | is_op('subtract')
    assert is_add_or_sub.match(relay.op.op.get("add"))
    assert is_add_or_sub.match(relay.op.op.get("subtract"))

根据属性匹配算子

接下来的例子是一个使用任何标记为element-wise的算子的dense运算:

def test_no_match_attr():
    op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
    op_pat = op(wildcard(), wildcard())
    x = relay.var('x')
    y = relay.var('y')
    assert not op_pat.match(relay.op.nn.dense(x, y))

下面是另一个使用指定属性匹配算子的例子:

def test_match_data_layout():
    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
    x = relay.var('x')
    y = relay.var('y')
    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))

或者一个指定卷积核大小的卷积:

def test_match_kernel_size():
    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
    x = relay.var('x')
    y = relay.var('y')
    assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3]))

可选算子匹配

下面是一个带可选算子模式匹配的例子。在这个模式中,我们可以匹配conv2d+bias_add+relu的图形或conv2d+bias_add的图形。

def test_match_optional():
    conv_node = is_op('nn.conv2d')(wildcard(), wildcard())
    bias_node = is_op('nn.bias_add')(conv_node, wildcard())
    pat = bias_node.optional(lambda x: is_op('nn.relu')(x))

    x = relay.var('x')
    y = relay.var('y')
    z = relay.var('z')
    conv2d = relay.op.nn.conv2d(x, y)
    bias = relay.op.nn.bias_add(conv2d, z)
    assert pat.match(bias)
    relu = relay.op.nn.relu(bias)
    assert pat.match(relu)

 匹配类型

 除了按属性匹配,我们还可以按类型(shape和数据类型)匹配算子。下面是一些例子: 

def test_match_type():
    # Match any op with float32
    pat1 = has_dtype('float32')
    x = relay.var('x', shape=(10, 10), dtype='float32')
    assert pat1.match(x)

    # Match any op with shape (10, 10)
    pat2 = has_shape((10, 10))
    x = relay.var('x', shape=(10, 10), dtype='float32')
    assert pat2.match(x)

    # Match conv2d+relu with a certain shape
    conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
    pat3 = is_op('nn.relu')(conv2d).has_shape((1, 32, 28, 28))

    x = relay.var('x', shape=(1, 3, 28, 28), dtype='float32')
    w = relay.var('w', shape=(32, 3, 3, 3), dtype='float32')
    conv2d = relay.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1))
    relu = relay.nn.relu(conv2d)
    assert pat3.match(relu)

匹配Non-Call节点

有时,我们可能还想匹配包含Tuple或TupleGetItem节点的模式。由于没有调用节点,我们需要使用特定的模式节点来匹配它们:

def test_match_tuple():
    x = relay.var('x')
    y = relay.var('y')
    z = relay.var('z')
    tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
    assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))

下个例子我们匹配batch_norm -> get(0) -> relu。注意你也可以使用s_tuple_get_item(bn_node)匹配任意序号的TupleGetItem

def test_match_tuple_get_item():
    bn_node = is_op('nn.batch_norm')(wildcard(), wildcard(), wildcard(), wildcard(), wildcard())
    tuple_get_item_node = is_tuple_get_item(bn_node, 0)
    pat = is_op('nn.relu')(tuple_get_item_node)

    x = relay.var('x', shape=(1, 8))
    gamma = relay.var("gamma", shape=(8,))
    beta = relay.var("beta", shape=(8,))
    moving_mean = relay.var("moving_mean", shape=(8,))
    moving_var = relay.var("moving_var", shape=(8,))
    bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var)
    tuple_get_item_node = bn_node[0]
    out = relay.nn.relu(tuple_get_item_node)
    pat.match(out)

如果我们有一个跨越函数边界的模式,我们可能想要匹配函数本身

def test_match_func():
    x = relay.var("x")
    y = relay.var("y")
    wc1 = wildcard()
    wc2 = wildcard()
    func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
    assert func_pattern.match(relay.Function([x, y], x + y))

下一个示例是根据常量节点的值匹配常量节点。这对于检查子图中的特定参数是否被绑定非常有用。

def test_match_constant():
    conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
    pattern = is_op('nn.bias_add')(conv2d, wildcard())

    x = relay.var('x', shape=(1, 3, 224, 224))
    w = relay.var('w', shape=(3, 3, 3, 3))
    b = relay.var('b', shape=(3, ))
    conv2d = relay.op.nn.conv2d(x, w)
    out = relay.op.nn.bias_add(conv2d, b)
    func = relay.Function([x, w, b], out)
    mod = tvm.IRModule.from_expr(func)

    # Two inputs of the conv2d in the graph are VarNode by default, so no match.
    assert not pattern.match(mod['main'].body)

    # The second input (weight) has been bind with constant values so it is now a constant node.
    mod["main"] = bind_params_by_name(mod["main"],
                                    {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
    assert pattern.match(mod['main'].body)

另一方面,如果需要将常量与特定值匹配,则可以直接使用is_expr。这对代数化简很有用。

def test_match_plus_zero():
    zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
    pattern = wildcard() + zero

    x = relay.Var('x')
    y = x + relay.const(0)
    assert pattern.match(y)

下一个例子是用一个特定的属性匹配函数节点:

def test_match_function():
    pattern = wildcard().has_attr({"Composite": "add"})

    x = relay.var('x')
    y = relay.var('y')
    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
    assert pattern.match(f)

如果一个Relay If表达式的所有条件、真分支和假分支都匹配,则该表达式可以被匹配:

def test_match_if():
    x = is_var("x")
    y = is_var("y")
    pat = is_if(is_op("less")(x, y), x, y)

    x = relay.var("x")
    y = relay.var("y")
    cond = x < y

    assert pat.match(relay.expr.If(cond, x, y))

如果一个Relay Let表达式的所有变量、值和主体都匹配,那么它就可以被匹配:

def test_match_let():
    x = is_var("x")
    y = is_var("y")
    let_var = is_var("let")
    pat = is_let(let_var, is_op("less")(x, y), let_var)

    x = relay.var("x")
    y = relay.var("y")
    lv = relay.var("let")
    cond = x < y
    assert pat.match(relay.expr.Let(lv, cond, lv))

 匹配Diamonds和Post-Dominator图

下一个例子是匹配一个顶部有两个输入的Diamond:

def test_match_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
    path1 = is_op('nn.relu')(is_conv2d)
    path2 = is_op('nn.leaky_relu')(is_conv2d)
    diamond = is_op('add')(path1, path2)

    # Expr
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
    out = relu + leaky_relu

    # Check
    assert diamond.match(out)

 最后一个例子是将diamond与post-dominator关系配对。我们在模式语言中嵌入支配者分析作为匹配类型,以允许在未知拓扑下进行模式匹配。这一点很重要,因为我们希望能够使用该语言来描述融合模式,就像单元运算后面跟着一个conv2d:

def test_match_dom_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
    reduction = is_op('add')(wildcard(), wildcard())
    diamond = dominates(is_conv2d, is_elemwise, reduction)

    # Expr
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
    out = relu + leaky_relu

    # Check
    assert diamond.match(out)

模糊匹配模式

上面的Dominator分析让我们匹配一个Relay AST的子图,但使用的模式并不要求精确的一对一匹配。在有些地方我们支持“模糊”匹配。

具有任意数量输入的元组、函数和调用节点可以通过传递None作为参数值来匹配,即:

tuple_pattern = is_tuple(None)
func_pattern = FunctionPattern(None, wildcard() + wildcard())
call_pattern = func_pattern(None)

通过限制参数的使用而不是参数的数量,可以让这些模式匹配更通用的类模式。

此外,我们支持通过函数体的模糊匹配来匹配函数,即受模式约束的函数体。模式FunctionPattern([is_var(), is_var()], wildcard() + wildcard()])将匹配relay.Function([x, y], x + y),但它也将匹配relay.Function([x, y], x * x + y)。在第二种情况下,模式并没有完美地约束函数体,所以导致匹配是模糊的。

模式语言设计

Relay的模式语言被设计成Relay的IR的镜像,并对常见场景提供额外的支持。模式语言的目标是为数据流图匹配和重写提供类似正则表达式的功能。

高层次的设计是引入一种模式语言,目前我们提出的语言是:

Pattern ::= expr
        | *
        | pattern(pattern1, ... patternN)
        | has_type(type)
        | has_dtype(type)
        | has_shape(shape)
        | has_attr(attrs)
        | is_var(name)
        | is_constant()
        | is_expr(expr)
        | is_op(op_name)
        | is_tuple()
        | is_tuple_get_item(pattern, index = None)
        | is_if(cond, tru, fls)
        | is_let(var, value, body)
        | pattern1 `|` pattern2
        | dominates(parent_pattern, path_pattern, child_pattern)
        | FunctionPattern(params, body)

然后,上面的语言提供了一个匹配接口,既可以选择子图,也可以验证图是否与模式匹配。

表达式模式

匹配一个字面量表达式

通配符

匹配任意表达式

类型模式

检查与嵌套模式匹配的表达式是否具有特定类型。

数据类型模式

检查与嵌套模式匹配的表达式是否具有特定的数据类型。

shape模式

检查与嵌套模式匹配的表达式是否具有特定的输出形状。

属性模式

检查与模式匹配的算子是否具有具有特定值的属性。

变量模式

检查表达式是否为一个Relay变量,并可选地提供一个与变量名匹配的名称。

备用

要么匹配第一个模式,要么匹配第二个模式。

支配

匹配子模式,为父模式找到一个匹配,确保子模式最终主导父模式(即,模式外的节点不会使用父模式的输出),并且在子模式和模式之间的任何节点都匹配路径模式。

函数模式

使用函数体和参数匹配函数

if模式

使用条件,真分支,假分支匹配一个if表达式

Let模式

使用变量,值和语句体匹配Let语句

应用

模式语言不仅提供模式匹配,还提供模式处理。这里我们介绍了两种模式处理方法并提供了一些示例。 

模式重写

如果你想用另一个子图替换匹配的模式,你可以利用重写转换。下面是一个使用单个batch_norm操作符重写一系列算术操作符的例子。构造函数参数require_type指示是否需要在回调之前运行InferType。

class BatchnormCallback(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        self.x = wildcard()
        self.var = wildcard()
        self.mean = wildcard()
        self.beta = wildcard()
        self.gamma = wildcard()
        self.eps = wildcard()

        self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        var = node_map[self.var][0]
        mean = node_map[self.mean][0]
        beta = node_map[self.beta][0]
        gamma = node_map[self.gamma][0]
        eps = node_map[self.eps][0]
        return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.numpy().item())[0]

    # A graph of arithmetic operators that are functional equivalent to batch_norm.
    x = relay.var('x')
    var = relay.var('var')
    mean = relay.var('mean')
    beta = relay.var('beta')
    gamma = relay.var('gamma')
    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta

    from tvm.relay.dataflow_pattern import rewrite
    out = rewrite(BatchnormCallback(), BN)
    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])

回调函数将在返回的模式上递归地调用,直到模式停止变更。因此,如果self. pattern匹配回调返回的图的任何部分,重写器将在循环中运行。如果你想避免多次重写,你可以给构造函数传递一个rewrite_once=True参数。

模式分割

如果你对重写不满意,想对匹配的子图执行更复杂的处理,那么可以考虑将匹配的子图分割为一个单独的Relay函数,并对该函数执行其他处理。这里我们使用pattern.partition为每个匹配的子图创建一个新的Relay函数。该功能类似于TVM中的op融合pass:

# A pattern matching conv2d+relu.
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

# A graph.
x = relay.var('input')
w = relay.var('weight')
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
print('relu')
# free_var %x: Tensor[(1, 3, 224, 224), float32]
# free_var %w: Tensor[(3, 3, 3, 3), float32]
# %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
# free_var %b: Tensor[(3), float32]
# nn.bias_add(%0, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */

# After partition.
print(pattern.partition(relu))
# free_var %x: Tensor[(1, 3, 224, 224), float32]
# free_var %w: Tensor[(3, 3, 3, 3), float32]
# free_var %b: Tensor[(3), float32]
# %1 = fn (%FunctionVar_0_0, %FunctionVar_0_1,
#          %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_") {
#   %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
#   nn.bias_add(%0, %FunctionVar_0_2)
# };
# %1(%x, %w, %b)

注意,你也可以为创建的函数指定属性:

print(pattern.partition(relu, {'Composite': 'one_layer'}))
# free_var %x: Tensor[(1, 3, 224, 224), float32]
# free_var %w: Tensor[(3, 3, 3, 3), float32]
# free_var %b: Tensor[(3), float32]
# %1 = fn (%FunctionVar_0_0, %FunctionVar_0_1,
#          %FunctionVar_0_2, Composite="one_layer",
#                            PartitionedFromPattern="nn.conv2d_nn.bias_add_") {
#   %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
#   nn.bias_add(%0, %FunctionVar_0_2)
# };
# %1(%x, %w, %b)

 如果需要一个不能使用模式语言描述的自定义检查函数,可以在分割时指定检查函数。下面的例子演示了一个检查子图输入数据布局的例子:

def check(pre):
    conv = pre.args[0]
    return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)

pattern.partition(relu, check=check)

在这个例子中,我们检查匹配的子图的第一个参数(即pre.args[0])的数据布局是否为“NCHW”,以及它的批处理大小是否为1。如果匹配模式的条件不能通过分析模式本身来验证,则此特性非常有用。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值