TVM Code Study -- Code Generation Flow (Part 1)

This post walks through how TVM generates runnable code for different deployment backends. TVM offers two entry points for code generation: tvm.build and relay.build. The two flows differ slightly: tvm.build compiles and optimizes a single operator, while relay.build compiles and optimizes an entire network graph. We start with relay.build; example code is shown below.

relay.build

import onnx
from tvm import relay

onnx_model = onnx.load('model/mobilenetv2.onnx')
# assumed input name/shape for MobileNetV2; adjust to the actual model
shape_dict = {'input': (1, 3, 224, 224)}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

target = 'llvm'  # assumed CPU target for this example
with relay.build_config(opt_level=0):
    graph, lib, params = relay.build_module.build(mod, target, params=params)
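
For context, here is a minimal, hedged sketch of how the three returned artifacts are typically consumed with the graph runtime; the input name 'input' and its shape are assumptions for illustration, not part of the original example.

import numpy as np
import tvm
from tvm.contrib import graph_runtime

ctx = tvm.cpu(0)
module = graph_runtime.create(graph, lib, ctx)  # wrap the graph JSON and compiled lib
module.set_input(**params)                      # bind the trained weights
module.set_input('input', np.random.rand(1, 3, 224, 224).astype('float32'))  # assumed input name/shape
module.run()
out = module.get_output(0).asnumpy()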

Tracing into relay.build, we first land in tvm/python/tvm/relay/build_module.py (relay/__init__.py imports the build function directly into the relay namespace, which is why the build_module layer is skipped). The build function here is a module-level function inside build_module.
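As a quick, hedged sanity check of that re-export:

from tvm import relay
from tvm.relay import build_module

# build is re-exported in relay/__init__.py, so both names refer to the same function
assert relay.build is build_module.build

The build function itself looks like this: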

def build(mod, target=None, target_host=None, params=None, mod_name='default'):
    # ignore some code...

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
        tophub_context = autotvm.tophub.context(list(target.values()))
    else:
        tophub_context = autotvm.util.EmptyContext()

    with tophub_context:
        bld_mod = BuildModule()
        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
        mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
        return mod

The function first checks whether AutoTVM has pre-tuned parameter records and constructs tophub_context accordingly; inside that scope it creates a BuildModule instance and calls its build method. The BuildModule class is implemented in build_module.py; part of the code is shown below.

class BuildModule(object):
    """Build an IR module to run on TVM graph runtime. This class is used
    to expose the `RelayBuildModule` APIs implemented in C++.
    """
    def __init__(self):
        self.mod = _build_module._BuildModule()
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]
        self._optimize = self.mod["optimize"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]

    def build(self, mod, target=None, target_host=None, params=None):
        target = _update_target(target)

        # Setup the params.
        if params:
            self._set_params(params)
        # Build the IR module
        self._build(mod, target, target_host)
        # Get artifacts
        graph_json = self.get_json()
        mod = self.get_module()
        params = self.get_params()

        return graph_json, mod, params

_build_module._BuildModule() in turn connects to the C++ implementation through the FFI layer in python/tvm/relay/_build_module.py.

"""The interface for building Relay functions exposed from C++."""
import tvm._ffi

tvm._ffi._init_api("relay.build_module", __name__)
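
tvm._ffi._init_api scans the global function registry for names prefixed with "relay.build_module." and binds each one into this Python module. A hedged manual equivalent using the public lookup API:

import tvm

# Look up the same registered C++ global that _init_api binds
f = tvm.get_global_func("relay.build_module._BuildModule")
bld = f()                # a runtime::Module wrapping the C++ RelayBuildModule
build_pf = bld["build"]  # indexing dispatches through GetFunction("build")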

The corresponding C++ implementation lives in tvm/src/relay/backend/build_module.cc.

runtime::Module RelayBuildCreate() {
  auto exec = make_object<RelayBuildModule>();
  return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = RelayBuildCreate();
});

That is, a RelayBuildModule is registered for Python to call. The build entry can be seen in the RelayBuildModule class, also in build_module.cc.

class RelayBuildModule : public runtime::ModuleNode {
 public:
  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    // ignore some code ....
    } else if (name == "build") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        CHECK_EQ(args.num_args, 3);
        this->Build(args[0], args[1], args[2]);
      });
    // ignore some code ....
    }
  }
  // ignore some code ....
};

TVM wraps the build entry in a PackedFunc that forwards its three arguments to the RelayBuildModule member function Build: this->Build(...).

  /*!
   * \brief Build relay IRModule for graph runtime
   *
   * \param mod Relay IRModule
   * \param target Target device
   * \param target_host Host target device
   */
  void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
    targets_ = targets;
    target_host_ = target_host;
    BuildRelay(mod, params_);
    // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
    CompileEngine::Global()->Clear();
  }

This in turn calls the member function BuildRelay.

void BuildRelay(IRModule relay_module,
                  const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
    // Relay IRModule -> IRModule optimizations.
    relay_module = Optimize(relay_module, targets_, params);
    // Get the updated function.
    auto func = Downcast<Function>(relay_module->Lookup("main"));

    // Generate code for the updated function.
    graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
    graph_codegen_->Init(nullptr, targets_);
    graph_codegen_->Codegen(func);

    ret_.graph_json = graph_codegen_->GetJSON();
    ret_.params = graph_codegen_->GetParams();

    auto lowered_funcs = graph_codegen_->GetIRModule();

    // When there is no lowered_funcs due to reasons such as optimization.
    if (lowered_funcs.size() == 0) {
      // skip some code ......
    } else {
      ret_.mod = tvm::build(lowered_funcs, target_host_);
    }

    Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
    // TODO(zhiics) We should be able to completely switch to MetadataModule no
    // matter whether there are external modules or not.
    if (!ext_mods.empty()) {
      ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods);
    }
  }

After all these jumps we finally reach the core of build. Step by step, TVM does the following:

  1. Graph optimization: relay_module = Optimize(relay_module, targets_, params)
  2. Graph generation
  3. Backend code generation

Graph optimization

This calls the member function Optimize, which performs device-independent optimizations on the computation graph.

IRModule Optimize(IRModule relay_module, const TargetsMap& targets,
                    const std::unordered_map<std::string, runtime::NDArray>& params) {
    // skip some code ....

    Array<Pass> pass_seqs;
    Array<runtime::String> entry_functions{"main"};
    pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
    pass_seqs.push_back(transform::ToBasicBlockNormalForm());

    // Run all dialect legalization passes.
    pass_seqs.push_back(relay::qnn::transform::Legalize());

    // Legalize pass is restricted to homogeneous execution for now.
    if (targets.size() == 1) {
      pass_seqs.push_back(transform::Legalize());
    }

    pass_seqs.push_back(transform::SimplifyInference());
    // skip some code ....
    pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
    pass_seqs.push_back(transform::SimplifyExpr());
    pass_seqs.push_back(transform::CombineParallelConv2D(3));
    pass_seqs.push_back(transform::CombineParallelDense(3));
    pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
    pass_seqs.push_back(transform::FoldConstant());
    pass_seqs.push_back(transform::FoldScaleAxis());
    pass_seqs.push_back(transform::CanonicalizeCast());
    pass_seqs.push_back(transform::CanonicalizeOps());

    // Alter layout transformation is only applied to homogeneous execution yet.
    if (targets.size() == 1) {
      pass_seqs.push_back(transform::AlterOpLayout());
    }

    // Fast math optimizations.
    pass_seqs.push_back(transform::FastMath());
    pass_seqs.push_back(transform::FoldConstant());

    // Create a sequential pass and perform optimizations.
    transform::Pass seq = transform::Sequential(pass_seqs);
    if (targets.size() == 1) {
      const auto& it = targets.begin();
      With<Target> tctx((*it).second);
      relay_module = seq(relay_module);
    } else {
      relay_module = seq(relay_module);
    }

    // Handle heterogeneous compilation.
    transform::PassContext pass_ctx = PassContext::Current();
    if (targets_.size() > 1) {
      Optional<Integer> opt_fallback_dev =
          pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast<int>(kDLCPU)));
      auto fallback_dev = opt_fallback_dev.value();
      CHECK_GT(fallback_dev->value, 0U);
      relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value);
    }

    // Fuse the operations if it is needed.
    relay_module = transform::FuseOps()(relay_module);
    relay_module = transform::InferType()(relay_module);
    // Inline the functions that have been lifted by the module scope.
    
    relay_module = transform::Inline()(relay_module);
    CHECK(relay_module.defined());

    return relay_module;
  }

An Array<Pass> pass_seqs is assembled and the passes are run in sequence to compile and optimize the graph.
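A comparable (much shorter) pass pipeline can also be driven from Python, since tvm.transform.Sequential mirrors the C++ transform::Sequential used above. A hedged sketch, applied to the mod obtained earlier:

import tvm
from tvm import relay

seq = tvm.transform.Sequential([
    relay.transform.RemoveUnusedFunctions(),
    relay.transform.SimplifyInference(),
    relay.transform.FoldConstant(),
    relay.transform.FuseOps(),
    relay.transform.InferType(),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)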

Graph generation

This step is implemented by the GraphCodegen class; the corresponding code in BuildRelay is

 // Generate code for the updated function.
    graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
    graph_codegen_->Init(nullptr, targets_);
    graph_codegen_->Codegen(func);

    ret_.graph_json = graph_codegen_->GetJSON();
    ret_.params = graph_codegen_->GetParams();

    auto lowered_funcs = graph_codegen_->GetIRModule();

The GraphCodegen class is implemented in tvm/src/relay/backend/build_module.cc.

struct GraphCodegen {
 public:
  GraphCodegen() {
    auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen");
    mod = (*pf)();
  }
  ~GraphCodegen() {}

  void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); }

  void Codegen(const Function& func) { CallFunc("codegen", func); }

  std::string GetJSON() { return CallFunc<std::string>("get_graph_json", nullptr); }

  Array<tvm::runtime::Module> GetExternalModules() {
    return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
  }

  // skip some code ...
 protected:
  tvm::runtime::Module mod;
 // skip some code ...
};

What actually gets called is the relay.build_module._GraphRuntimeCodegen function registered in tvm/src/relay/backend/graph_runtime_codegen.cc.

runtime::Module CreateGraphCodegenMod() {
  auto ptr = make_object<GraphRuntimeCodegenModule>();
  return runtime::Module(ptr);
}

TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });

The GraphRuntimeCodegenModule class exposes a codegen method, which in turn calls the Codegen method of the GraphRuntimeCodegen class; this is what finally produces the LoweredOutput containing the graph JSON, the parameters, and the lowered functions.

class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
 public:
  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
    compile_engine_ = CompileEngine::Global();
    targets_ = targets;
  }

  LoweredOutput Codegen(relay::Function func) {
    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
    storage_device_map_ = (*pf)(func);
    // First we convert all the parameters into input nodes.
    for (auto param : func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();
    ret.params = params_;

    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, IRModule());
      }
      auto& mod = ret.lowered_funcs[kv.first];
      mod->Update(kv.second);
      ret.lowered_funcs.Set(kv.first, mod);
    }
    ret.external_mods = compile_engine_->LowerExternalFunctions();
    return ret;
  }
  // skip some code....
};
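
The graph_json assembled in Codegen is the same graph string that relay.build returned in the opening example. A hedged peek at its structure (the key names below come from the graph runtime's JSON format):

import json

g = json.loads(graph)   # `graph` from the relay.build example at the top
print(list(g.keys()))   # typically: nodes, arg_nodes, heads, attrs, node_row_ptr
print(g['nodes'][0])    # first node, usually a graph input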

Backend code generation

Relay now holds the lowered functions, and the final step hands them to tvm::build for code generation. This jumps to the build functions in tvm/src/driver/driver_api.cc (note that several overloads exist) and then into the core build. Note that this build supports heterogeneous compilation: the inputs map simply has to be partitioned across the different hardware targets.

// Build for heterogeneous execution.
runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_host) {
  auto pass_ctx = transform::PassContext::Current();

  std::vector<runtime::Module> device_modules;
  Target target_host_val = target_host;
  // skip some code ...

  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());

  for (const auto& it : inputs) {
    auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx);
    auto& mhost = pair.first;
    auto& mdevice = pair.second;

    mhost_all->Update(mhost);
    if (mdevice->functions.size() != 0) {
      device_modules.push_back(codegen::Build(mdevice, it.first));
    }
  }

  runtime::Module mhost = codegen::Build(mhost_all, target_host_val);
  // Import all modules
  for (const auto& it : device_modules) {
    if (it.operator->()) {
      mhost.Import(it);
    }
  }
  return mhost;
}
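
This tvm::build entry is also what the single-operator flow mentioned in the introduction reaches directly. A hedged sketch of that path using a tiny tensor-expression kernel (the kernel and its name add_one are made up for illustration):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

lowered = tvm.lower(s, [A, B], name="add_one")  # IRModule of lowered functions
mod = tvm.build(lowered, target="llvm")         # lands in codegen::Build below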

The very core of all this is codegen::Build, which takes us into the code generation module proper (tvm/src/target/codegen.cc). It dispatches to a build function registered under the backend's name:

runtime::Module Build(IRModule mod, Target target) {
  // skip some code.....
  std::string build_f_name;
  if (target->kind->name == "micro_dev") {
    build_f_name = "target.build.c";
  } else {
    build_f_name = "target.build." + target->kind->name;
  }
  // the build function.
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  CHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}
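
Because the dispatch goes through the same global registry seen earlier, the available backends can be probed from Python. A hedged check:

import tvm

# target.build.llvm is only registered when TVM is built with LLVM support
f = tvm.get_global_func("target.build.llvm", allow_missing=True)
print("LLVM backend enabled:", f is not None)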

Taking LLVM as an example: target.build.llvm is registered in tvm/src/target/llvm/llvm_module.cc and calls LLVMModuleNode::Init in the same file, which then jumps into the CodeGenLLVM class in tvm/src/target/llvm/codegen_llvm.cc to perform the actual code generation.


This completes the relay.build flow and produces code that can run on the target backend.

The above is only a rough pass over the relay compilation flow; some of the code was not examined closely, and the write-up is correspondingly rough. I plan to revisit and tidy it up when time permits.
