【TVM源码学习笔记】2.2 C++侧的relay ir op, function和irmodule

 1 Relay ir算子实现

前文分析onnx卷积算子转tvm relay ir,最后是从python调用到在C++侧的MakeConv接口: 

template <typename T>
inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                     Array<IndexExpr> dilation, int groups, IndexExpr channels,
                     Array<IndexExpr> kernel_size, std::string data_layout,
                     std::string kernel_layout, std::string out_layout, DataType out_dtype,
                     std::string op_name) {
  auto attrs = make_object<T>();
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->dilation = std::move(dilation);
  attrs->groups = groups;
  attrs->channels = std::move(channels);
  attrs->kernel_size = std::move(kernel_size);
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_layout = std::move(out_layout);
  attrs->out_dtype = std::move(out_dtype);
  const Op& op = Op::Get(op_name);
  return Call(op, {data, weight}, Attrs(attrs), {});
}

代码中先保存了传入的卷积参数,然后调用Op::Get获取一个Op对象实例,将Op、输入数据、权重数据和属性作为参数,返回一个Call实例。

1 Op::Get

Op类定义在include/tvm/ir/op.h中。其中Op::Get方法代码:

// find operator by name
const Op& Op::Get(const String& name) {
  const OpRegEntry* reg = OpRegistry::Global()->Get(name);
  ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered";
  return reg->op();
}

从注释看是根据名字查找算子。OpRegistry是

using OpRegistry = AttrRegistry<OpRegEntry, Op>;

类和Get的定义:

/*!
 * \brief Implementation of registry with attributes.
 *
 * \tparam EntryType The type of the registry entry.
 * \tparam KeyType The actual key that is used to lookup the attributes.
 *                 each entry has a corresponding key by default.
 */
template <typename EntryType, typename KeyType>
class AttrRegistry {
 public:
  using TSelf = AttrRegistry<EntryType, KeyType>;
  /*!
   * \brief Get an entry from the registry.
   * \param name The name of the item.
   * \return The corresponding entry.
   */
  const EntryType* Get(const String& name) const {
    auto it = entry_map_.find(name);
    if (it != entry_map_.end()) return it->second;
    return nullptr;
  }

1. 在Op::Get里面,模板参数EntryType和KeyType分别是OpRegEntry类和Op类;

2. Op::Get里面OpRegistry::Global()->Get(name)这种调用方式,说明OpRegistry是个单实例类,所以上面代码中的OpRegistry::entry_map_也是个全局唯一的;

3.OpRegistry::Get是从entry_map_中根据传入的名字name找到一个表项的key。对2D卷积算子来说,这里参数name为nn.conv2d

这个全局唯一的entry_map_是什么时候往里面写数据的呢?写数据的接口是OpRegistry::RegisterOrGet:

  /*!
   * \brief Get an entry or register a new one.
   * \param name The name of the item.
   * \return The corresponding entry.
   */
  EntryType& RegisterOrGet(const String& name) {
    auto it = entry_map_.find(name);
    if (it != entry_map_.end()) return *it->second;
    uint32_t registry_index = static_cast<uint32_t>(entries_.size());
    auto entry = std::unique_ptr<EntryType>(new EntryType(registry_index));
    auto* eptr = entry.get();
    eptr->name = name;
    entry_map_[name] = eptr;
    entries_.emplace_back(std::move(entry));
    return *eptr;
  }

在RegisterOrGet中,先是在entry_map_表中查找key为name表项是否存在;如果存在,直接返回该表象的value;如果不存在,就new一个EntryType类型实例,然后用get获取类型指针(由entry_map_的定义倒推,eptr为EntryType*类型),将这个数据实例加入到entry_map_表和entries_表中。

搜索RegisterOrGet的调用,可以看到在RELAY_REGISTER_OP宏的定义中有使用:

#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName)


#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)

#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op


#define TVM_REGISTER_OP(OpName)                          \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()

RELAY_REGISTER_OP用于注册一个算子。 详细分析可以参考
深入理解TVM:RELAY_REGISTER_OP

TVM_OBJECT_REG_VAR_DEF定义了一个静态变量,包括变量类型 (::tvm::OpRegEntry&)和变量名的前半部分;

TVM_STR_CONCAT将__COUNTER__和前半部分拼接在一起成为完整的变量名。

__COUNTER__宏是一个计数器,保证在编译过程中产生一个独一无二的数字,这样这个拼接后的变量名也将是独一无二的。

算子实现和注册可以参考

Adding an Operator to Relay 

例如conv2d的注册:

RELAY_REGISTER_OP("nn.conv2d")
    .describe(R"code(2D convolution layer (e.g. spatial convolution over images).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

综上所述,我们理下整个流程:

1. 当tvm中实现一个算子时,会调用 RELAY_REGISTER_OP进行注册;

2. 该注册会在 AttrRegistry<OpRegEntry, Op>(这是个单例模式的类)的entry_map_中加入一个OpRegEntry实例;

3. 而tvm处理一个外部输入的模型时,如果遇到这个算子,在Op::Get方法中从entry_map_表中读取对应的OpRegEntry实例:

const OpRegEntry* reg = OpRegistry::Global()->Get(name)

并执行

return reg->op()

获取和返回对应的Op实例:

class OpRegEntry {
 public:
  /*! \return the operator */
  const Op& op() const { return op_; }
  ...
 private:
  ...
  /*! \brief The operator */
  Op op_;
  ...
}

OpRegEntry::OpRegEntry(uint32_t reg_index) {
  ObjectPtr<OpNode> n = make_object<OpNode>();
  n->index_ = reg_index;
  op_ = Op(n);
}

Op类继承自ObjectRef,对应的数据类型为OpNode。在该类中记录了算子的名称,类型,属性,输入等信息。还提供了属性的访问入口VisitAttrs

2 返回Call实例

回过头来我们继续看MakeConv。在调用Op::Get获取算子实例后,执行

 return Call(op, {data, weight}, Attrs(attrs), {});

返回一个Call实例。我们看下Call的定义:

class Call : public Expr {
 public:
  /*!
   * \brief The destructor
   */
  ~Call();

  /*!
   * \brief The constructor
   * \param op The operator will be invoked.
   * \param args The arguments of the call.
   * \param attrs The attributes of the call node.
   * \param type_args The type arguments passed to a polymorphic function.
   * \param span The source span of the expression.
   */
  TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
               Array<Type> type_args = Array<Type>(), Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};

以及构造函数:

Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
  ObjectPtr<CallNode> n = make_object<CallNode>();
  n->op = std::move(op);
  n->args = std::move(args);
  n->attrs = std::move(attrs);
  n->type_args = std::move(type_args);
  n->virtual_device_ = VirtualDevice::FullyUnconstrained();
  n->span = std::move(span);
  data_ = std::move(n);
}

Call实例化的时候,生成了一个CallNode,记录了当前算子的参数和注册Op对象。所以我们可以简单的理解这里只是为当前算子生成了一个对象,记录下算子类型和参数而已。

这里需要注意的有两点:

1,传入的op对象没有做任何修改。因为本身这个op对象也是算子注册时生成的,全局唯一的一个实例,它记录了一些算子的公共信息,与当前算子的参数无关;

2. 这里生成了一个CallNode。Call和CallNode是什么关系呢?这个就涉及到TVM C++代码的对象基石Object家族。要分析TVM C++代码,有必理解Object类。详细分析参见【TVM源码学习笔记】附录2 TVM的Object家族

2 relay Function 

python前端在将解析的模型打包成Function时,调用到C++时,返回的是一个relay::Function实例:

TVM_REGISTER_GLOBAL("relay.ir.Function")
    .set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,
                       tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
      return Function(params, body, ret_type, ty_params, attrs);
    });

relay::Function类比较简单,只有一个构造函数和一些ObjectRef类需要设置的公共属性。我们看下构造函数:

Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
                   tvm::Array<TypeVar> type_params, DictAttrs attrs, Span span) {
  ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
  ICHECK(params.defined());
  ICHECK(type_params.defined());
  n->params = std::move(params);
  n->body = std::move(body);
  n->ret_type = std::move(ret_type);
  n->type_params = std::move(type_params);
  n->attrs = std::move(attrs);
  n->virtual_device_ = VirtualDevice::FullyUnconstrained();
  n->span = std::move(span);
  data_ = std::move(n);
}

这里只是创建了一个FunctionNode实例,记录下传入的函数体、参数和属性。

3 relay IRModule

from_onnx最后返回IRModule的代码:

return IRModule.from_expr(func), self._params

这个最后调用的是C++里面注册的ir.Module_FromExpr的:

TVM_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr);

执行函数体中的IRModule::FromExpr,实现如下:

IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs,
                            const Map<GlobalTypeVar, TypeData>& type_definitions) {
  return FromExprInContext(expr, global_funcs, type_definitions).first;
}

这里从python传入的func作为expr参数了。 在函数声明的时候,后面两个参数默认是空:

  /*!
   * \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no
   * imports.
   */
  TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
                                   const Map<GlobalVar, BaseFunc>& global_funcs = {},
                                   const Map<GlobalTypeVar, TypeData>& type_definitions = {});

所以当前上下文环境下,我们忽略这两个参数。

std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
    const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
    const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
    std::unordered_set<String> import_set) {
  auto mod = IRModule(global_funcs, type_definitions, std::move(import_set));
  String gv_name;

  // All global definitions must be functions.
  BaseFunc func;
  if (auto* func_node = expr.as<BaseFuncNode>()) {
    func = GetRef<BaseFunc>(func_node);
    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
      // Function literal has been annotated with it's required global symbol.
      gv_name = opt.value();
    }
  } else {
    func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
  }

  if (gv_name.empty()) {
    // Bind function to 'main' (though rename if would clash with existing 'main').
    gv_name = mod->GetUniqueName("main");
  }

  GlobalVar main_gv(gv_name);
  mod->Add(main_gv, func);
  return {mod, main_gv};
}

这里先创建了一个IRModule对象,然后将传入的函数命名为main,然后将全局变量main_gv和函数func添加到mod中。

在TVM中,一般使用Node的定义继承自Object,不带Node的继承自ObjectRef。IRModule也不例外。所以IRModule是一个引用类型,真正存储数据的是IRModuleNode。Add方法也是IRModule的方法:

void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
  BaseFunc checked_func = f;
  if (auto* ptr = f.as<relay::FunctionNode>()) {
    WarnIfMalformed(GetRef<IRModule>(this), GetRef<relay::Function>(ptr));
  }

  AddUnchecked(var, checked_func);
}

void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
  this->functions.Set(var, func);

  auto it = global_var_map_.find(var->name_hint);
  if (it != global_var_map_.end()) {
    ICHECK_EQ((*it).second, var);
  } else {
    ICHECK(global_var_map_.count(var->name_hint) == 0)
        << "Duplicate global function name " << PrettyPrint(var);
  }

  global_var_map_.Set(var->name_hint, var);
}

可以看到,全局变量main_gv最后被加入了全局变量表,func加入了全局函数表

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值