向Relay中添加操作符

14 篇文章 6 订阅

目录

         注册操作符

创建一个调用节点

引用一个Python API 

梯度运算符

在Python中添加梯度

在C ++中添加求导

总结


为了从Relay IR中使用TVM操作符,需要在Relay中注册操作符,以确保将其集成到Relay的类型系统中。

注册操作符需要三个步骤:

  • 使用C ++中的宏RELAY_REGISTER_OP注册操作符的Arity和类型信息

  • 定义一个C ++函数为操作符生成一个调用节点,并为该函数注册一个Python API挂钩

  • 将上述Python API挂钩包装在更整洁的接口中

src/relay/op/tensor/binary.cc文件提供了前两个步骤的示例,同时python/tvm/relay/op/tensor.py提供了最后一个步骤示例。

注册操作符

TVM已经具有操作符注册表,但是如果没有其他类型信息,Relay无法正确合并TVM操作符。

为了在注册操作符时具有灵活性,并在Relay中表达类型时提供更大的表达性和粒度,使用输入和输出类型之间的关系来确定操作符类型。这些关系表示为函数,该函数接受输入类型和输出类型列表(这些类型中的任何一个都不完整)并返回满足该关系的输入和输出类型列表。本质上,操作符的关系除了计算输出类型外,还可以强制执行所有必要的键入规则(即,通过检查输入类型)。

例如,请参见src/relay/op/type_relations.h及其实现。例如,BroadcastRel接受两个输入类型和一个输出类型,检查它们都是具有相同基础数据类型的张量类型,最后确保输出类型的形状是输入类型形状的广播。

type_relations.h 如果现有类型未捕获所需运算符的行为,则可能有必要添加另一种类型关系。

C ++中的宏RELAY_REGISTER_OP允许开发人员指定以下有关Relay中的运算符的信息:

  • Arity(参数个数)

  • 位置参数的名称和说明

  • 支持级别(1表示内部固有的;较高的数字表示集成度或外部支持的操作符较少)

  • 运算符的类型关系

下面的示例来自binary.cc张量,并将其用于张量:

RELAY_REGISTER_OP("add")
    .set_num_inputs(2)
    .add_argument("lhs", "Tensor", "The left hand side tensor.")
    .add_argument("rhs", "Tensor", "The right hand side tensor.")
    .set_support_level(1)
    .add_type_rel("Broadcast", BroadcastRel);

创建一个调用节点

此步骤仅需要简单地编写一个将参数带给操作符的函数(作为Relay表达式),然后将调用节点返回给运算符(即,应将其放置在打算调用运算符的Relay AST中的节点)。

目前不支持调用属性和类型参数(最后两个字段),因此用【Op::Get】足以从操作符注册表中获取操作符信息,并将参数传递给调用节点,如下所示。

TVM_REGISTER_GLOBAL("relay.op._make.add")
    .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
        static const Op& op = Op::Get("add");
      return CallNode::make(op, {lhs, rhs}, Attrs(), {});
    });


引用一个Python API 

通常,Relay中的约定是,通过TVM_REGISTER_GLOBAL导出的函数应该包装在单独的Python函数中,而不是直接在Python中调用。对于产生对运算符的调用的函数,将它们捆绑起来可能很方便,如中所示python/tvm/relay/op/tensor.py,其中都提供了张量上的元素运算符。例如,以下是上一节中的add函数在Python中的显示方式:

def add(lhs, rhs):
    """Elementwise addition.

    Parameters
    ----------
    lhs : relay.Expr
        The left hand side input data
    rhs : relay.Expr
        The right hand side input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.add(lhs, rhs)
 

请注意,这些Python包装器也可能是向操作符提供更简单接口的好机会。例如,该 【concat】运算符被注册为仅接受一个运算符,即具有要连接的张量的元组,但是Python包装器将这些张量作为参数并将其组合成一个元组,然后生成调用节点:

def concat(*args):
    """Concatenate the input tensors along the zero axis.

    Parameters
    ----------
    args: list of Tensor

    Returns
    -------
    tensor: The concatenated tensor.
    """
    tup = Tuple(list(args))
    return _make.concat(tup)
 

梯度运算符

梯度运算符对于在Relay中编写可微分的程序很重要。尽管Relay的autodiff算法可以区分一流的语言结构,但运算符是不透明的。由于Relay无法查看具体实现,因此必须提供明确的微分规则。

Python和C ++均可用于编写梯度运算符,但是我们将示例重点放在Python上,因为它更常用。

在Python中添加梯度

可以在python/tvm/relay/op/_tensor_grad.py中找到Python梯度运算符的集合 。我们将通过两个有代表性的示例:sigmoidmultiply

@register_gradient("sigmoid")
def sigmoid_grad(orig, grad):
    """Returns [grad * sigmoid(x) * (1 - sigmoid(x))]."""
    return [grad * orig * (ones_like(orig) - orig)]
 

这里的输入是原始运算符【orig】和【grad】要累加的梯度。我们返回的是一个列表,其中第i个索引处的元素是运算符相对于运算符第i个输入的导数。通常,梯度将返回一个列表,其中包含与基本运算符输入相同数量的元素。

在进一步分析这个定义之前,首先我们应该回顾一下Sigmod函数的导数:【 σ/x=σ(x)(1σ(x))】。上面的定义看起来与数学定义相似,但是有一个重要的补充,我们将在下面进行描述。

该术语【 σ/x=σ(x)(1σ(x))】直接与导数匹配,因为这里【orig】是Sigmod型函数,但我们不仅对如何计算此函数的梯度感兴趣。我们有兴趣将此梯度与其他梯度组合,我们可以在整个程序中累积该梯度。这就是为什么出现【grad】。在表达式【grad * orig * (ones_like(orig) - orig)】中,乘以【grad】指定到目前为止的梯度如何组成导数。

现在,我们来看multiply一个更有趣的示例:

@register_gradient("multiply")
def multiply_grad(orig, grad):
    """Returns [grad * y, grad * x]"""
    x, y = orig.args
    return [collapse_sum_like(grad * y, x),
            collapse_sum_like(grad * x, y)]
 

在此示例中,返回列表中有两个元素,因为【multiply】 它是二元运算符。再回想一下,如果f(x,y)=xy,偏导数是【f/x=y∂f∂x=y】  和【 f/y=x∂f∂y=x】。

multiply】有一个必需的步骤,因为广播具有语义,因此multiply不需要sigmod 。由于【grad】的形状可能与输入 的形状不匹配,因此我们使用【collapse_sum_like 】获取【grad * <var>】项的内容,并使形状与我们要微分的输入的形状相匹配。

在C ++中添加求导

在C ++中添加渐变求导类似于在Python中添加求导,但是注册的接口略有不同。

首先,确保src/relay/pass/pattern_util.h包括在内。它提供了用于在Relay AST中创建节点的辅助功能。然后,以类似于Python示例的方式定义求导:

tvm::Array<Expr> MultiplyGrad(const Expr& orig_call, const Expr& output_grad) {
    const Call& call = orig_call.Downcast<Call>();
    return { CollapseSumLike(Multiply(output_grad, call.args[1]), call.args[0]),
             CollapseSumLike(Multiply(output_grad, call.args[0]), call.args[1]) };
}
 

请注意,在C ++中,我们不能使用与Python中相同的运算符重载,并且需要向下转换,因此实现更为冗长。即便如此,我们仍可以轻松地验证此定义是否与Python中的先前示例相同。

现在,代替使用Python装饰器,我们需要为【“FPrimalGradient”】追踪【set_attr】的调用,在基本运算符注册的末尾,以注册求导。

RELAY_REGISTER_OP("multiply")
    // ...
    // Set other attributes
    // ...
    .set_attr<FPrimalGradient>("FPrimalGradient", MultiplyGrad);


总结

  • TVM操作符可以使用表示适当类型信息的关系在Relay中注册。

  • 在Relay中使用操作符需要一个函数来为操作符生成调用节点。

  • 最好有一个简单的Python包装器来生成调用节点

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值