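This article walks through how TVM generates executable code for different backend deployment platforms. TVM offers two entry points for code generation: tvm.build and relay.build. Their flows differ slightly: tvm.build compiles and optimizes a single operator, while relay.build compiles and optimizes the computation graph of an entire network. We start with relay.build; example code is shown below.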
relay.build
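```python
import onnx
from tvm import relay

onnx_model = onnx.load('model/mobilenetv2.onnx')
shape_dict = {'input': (1, 3, 224, 224)}  # assumed input name/shape; match your model
target = 'llvm'  # assumed target
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
dtype = 'float32'
with relay.build_config(opt_level=0):
    graph, lib, params = relay.build_module.build(mod, target, params=params)
```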
Tracing relay.build, we first land in tvm/python/tvm/relay/build_module.py. (relay/__init__.py imports build directly into the relay namespace, which is why the call skips the build_module layer; build is a module-level function in build_module.)
```python
def build(mod, target=None, target_host=None, params=None, mod_name='default'):
    # ... some code omitted ...

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
        tophub_context = autotvm.tophub.context(list(target.values()))
    else:
        tophub_context = autotvm.util.EmptyContext()

    with tophub_context:
        bld_mod = BuildModule()
        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
        mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
        return mod
```
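The function first checks whether AutoTVM has pre-tuned parameter records: if the current dispatch context is the default fallback context, it builds a tophub_context that loads them from TopHub; otherwise it uses an empty context. Inside that context it creates a BuildModule instance and calls its build method. The BuildModule class is also implemented in build_module.py; part of it is shown below.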
```python
class BuildModule(object):
    """Build an IR module to run on TVM graph runtime. This class is used
    to expose the `RelayBuildModule` APIs implemented in C++.
    """

    def __init__(self):
        self.mod = _build_module._BuildModule()
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]
        self._optimize = self.mod["optimize"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]

    def build(self, mod, target=None, target_host=None, params=None):
        target = _update_target(target)

        # Setup the params.
        if params:
            self._set_params(params)
        # Build the IR module
        self._build(mod, target, target_host)
        # Get artifacts
        graph_json = self.get_json()
        mod = self.get_module()
        params = self.get_params()

        return graph_json, mod, params
```
_build_module._BuildModule(), in turn, is bound to a C++ function through the FFI in python/tvm/relay/_build_module.py:
```python
"""The interface for building Relay functions exposed from C++."""
import tvm._ffi

tvm._ffi._init_api("relay.build_module", __name__)
```
The corresponding C++ implementation is in tvm/src/relay/backend/build_module.cc:
```cpp
runtime::Module RelayBuildCreate() {
  auto exec = make_object<RelayBuildModule>();
  return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = RelayBuildCreate();
});
```
In other words, this registers a global function that creates a RelayBuildModule. The build function is implemented in the RelayBuildModule class, also in build_module.cc:
```cpp
class RelayBuildModule : public runtime::ModuleNode {
 public:
  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    // ignore some code ....
    } else if (name == "build") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        CHECK_EQ(args.num_args, 3);
        this->Build(args[0], args[1], args[2]);
      });
    // ignore some code ....
  }
```
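GetFunction wraps build in a PackedFunc, which checks that exactly three arguments were passed and forwards them to the member function RelayBuildModule::Build, i.e. this->Build(...):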
```cpp
  /*!
   * \brief Build relay IRModule for graph runtime
   *
   * \param mod Relay IRModule
   * \param target Target device
   * \param target_host Host target device
   */
  void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
    targets_ = targets;
    target_host_ = target_host;
    BuildRelay(mod, params_);
    // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
    CompileEngine::Global()->Clear();
  }
```
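The GetFunction, PackedFunc and Build chain can be imitated in a few lines of Python. This is a toy model under assumptions, not TVM code: ToyRelayBuildModule, get_function and _build are hypothetical. It shows the two halves of the pattern: the module hands out closures by name (with the same three-argument check), and the Python side looks them up once with mod["build"], as BuildModule.__init__ does.

```python
class ToyRelayBuildModule:
    """Toy analogue of RelayBuildModule: hands out closures ("packed functions") by name."""

    def __init__(self):
        self.targets = None
        self.target_host = None
        self.compiled = None

    def get_function(self, name):
        if name == "build":
            def build(*args):
                # Mirrors CHECK_EQ(args.num_args, 3) in the C++ lambda.
                if len(args) != 3:
                    raise ValueError(f"build expects 3 arguments, got {len(args)}")
                self._build(*args)
            return build
        if name == "get_module":
            return lambda: self.compiled
        return None  # unknown name

    def __getitem__(self, name):
        # Loosely like the Python side's self.mod["build"] lookup.
        fn = self.get_function(name)
        if fn is None:
            raise AttributeError(f"module has no function {name!r}")
        return fn

    def _build(self, mod, targets, target_host):
        # Stand-in for RelayBuildModule::Build, which would call BuildRelay.
        self.targets = targets
        self.target_host = target_host
        self.compiled = f"compiled<{mod}>"


m = ToyRelayBuildModule()
_build = m["build"]  # bound once, like BuildModule.__init__ does
_build("main_mod", {"cpu": "llvm"}, "llvm")
```

Because the closure captures the module instance, the state set by build is visible to later lookups such as get_module, which is the same reason the C++ lambda captures this and sptr_to_self.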
Build in turn calls the member function BuildRelay:
```cpp
  void BuildRelay(IRModule relay_module,
                  const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
    // Relay IRModule -> IRModule optimizations.
    relay_module = Optimize(relay_module, targets_, params);
    // Get the updated function.
    auto func = Downcast<Function>(relay_module->Lookup("main"));

    // Generate code for the updated function.
    graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
    graph_codegen_->Init(nullptr, targets_);
    graph_codegen_->Codegen(func);

    ret_.graph_json = graph_codegen_->GetJSON();
    ret_.params = graph_codegen_->GetParams();

    auto lowered_funcs = graph_codegen_->GetIRModule();

    // When there is no lowered_funcs due to reasons such as optimization.
    if (lowered_funcs.size() == 0) {
      // skip some code ...... (an empty module is created for the host target)
    } else {
      ret_.mod = tvm::build(lowered_funcs, target_host_);
    }

    Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
    // TODO(zhiics) We should be able to completely switch to MetadataModule no
    // matter whether there are external modules or not.
    if (!ext_mods.empty()) {
      ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods);
    }
  }
```
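After several hops we finally reach the core of build. TVM does the following, step by step:
- Computation graph optimization: relay_module = Optimize(relay_module, targets_, params)
- Computation graph generation: GraphCodegen turns the optimized main function into graph JSON, params and lowered functions
- Backend code generation: tvm::build compiles the lowered functions into a runtime module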
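The three stages, plus the two special cases in BuildRelay (nothing left to lower, and external modules), can be outlined as a toy driver. This is a sketch only: build_relay and every injected callable are hypothetical placeholders for Optimize, GraphCodegen, tvm::build, the empty-module fallback and CreateMetadataModule.

```python
def build_relay(relay_module, optimize, graph_codegen, tvm_build,
                make_empty_module, create_metadata_module):
    """Toy outline of BuildRelay; every stage is an injected (hypothetical) callable."""
    # 1. Computation graph optimization (Optimize).
    relay_module = optimize(relay_module)
    # 2. Computation graph generation (GraphCodegen on "main").
    graph_json, params, lowered_funcs, ext_mods = graph_codegen(relay_module["main"])
    # 3. Backend code generation, or an empty module if nothing was lowered.
    mod = tvm_build(lowered_funcs) if lowered_funcs else make_empty_module()
    # Modules from external codegens are wrapped together with the host module.
    if ext_mods:
        mod = create_metadata_module(params, mod, ext_mods)
    return graph_json, mod, params
```

Injecting the stages keeps the control flow visible without pulling in any of TVM's types.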
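Computation graph optimization
This step calls the member function Optimize, which runs graph-level optimizations that are mostly target-independent; a few passes, such as Legalize and AlterOpLayout, are only added for a single target and run under that target's context.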
```cpp
  IRModule Optimize(IRModule relay_module, const TargetsMap& targets,
                    const std::unordered_map<std::string, runtime::NDArray>& params) {
    // skip some code ....
    Array<Pass> pass_seqs;
    Array<runtime::String> entry_functions{"main"};
    pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
    pass_seqs.push_back(transform::ToBasicBlockNormalForm());

    // Run all dialect legalization passes.
    pass_seqs.push_back(relay::qnn::transform::Legalize());

    // Legalize pass is restricted to homogeneous execution for now.
    if (targets.size() == 1) {
      pass_seqs.push_back(transform::Legalize());
    }

    pass_seqs.push_back(transform::SimplifyInference());
    // skip some code ....
    pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
    pass_seqs.push_back(transform::SimplifyExpr());
    pass_seqs.push_back(transform::CombineParallelConv2D(3));
    pass_seqs.push_back(transform::CombineParallelDense(3));
    pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
    pass_seqs.push_back(transform::FoldConstant());
    pass_seqs.push_back(transform::FoldScaleAxis());
    pass_seqs.push_back(transform::CanonicalizeCast());
    pass_seqs.push_back(transform::CanonicalizeOps());

    // Alter layout transformation is only applied to homogeneous execution yet.
    if (targets.size() == 1) {
      pass_seqs.push_back(transform::AlterOpLayout());
    }

    // Fast math optimizations.
    pass_seqs.push_back(transform::FastMath());
    pass_seqs.push_back(transform::FoldConstant());

    // Create a sequential pass and perform optimizations.
    transform::Pass seq = transform::Sequential(pass_seqs);
    if (targets.size() == 1) {
      const auto& it = targets.begin();
      With<Target> tctx((*it).second);
      relay_module = seq(relay_module);
    } else {
      relay_module = seq(relay_module);
    }

    // Handle heterogeneous compilation.
    transform::PassContext pass_ctx = PassContext::Current();
    if (targets_.size() > 1) {
      Optional<Integer> opt_fallback_dev =
          pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast<int>(kDLCPU)));
      auto fallback_dev = opt_fallback_dev.value();
      CHECK_GT(fallback_dev->value, 0U);
      relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value);
    }

    // Fuse the operations if it is needed.
    relay_module = transform::FuseOps()(relay_module);
    relay_module = transform::InferType()(relay_module);
    // Inline the functions that have been lifted by the module scope.
    relay_module = transform::Inline()(relay_module);
    CHECK(relay_module.defined());

    return relay_module;
  }
```
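Optimize collects the passes into Array<Pass> pass_seqs, wraps them in a single Sequential pass and runs it over the module. It then annotates devices for heterogeneous compilation (only when there is more than one target) and finally runs FuseOps, InferType and Inline.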
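The "collect passes into a list, conditionally append target-dependent ones, then run them as one sequential pass" pattern can be sketched on a toy IR. This is not Relay: the IR is just a list of op names, and drop_dropout and alter_conv_layout are hypothetical stand-ins loosely inspired by SimplifyInference and AlterOpLayout.

```python
from typing import Callable, List

Module = List[str]                 # toy IR: a flat list of op names
Pass = Callable[[Module], Module]


def sequential(passes: List[Pass]) -> Pass:
    """Toy analogue of transform::Sequential: apply passes in order."""
    def run(mod: Module) -> Module:
        for p in passes:
            mod = p(mod)
        return mod
    return run


def drop_dropout(mod: Module) -> Module:
    # Hypothetical inference simplification: dropout is a no-op at inference time.
    return [op for op in mod if op != "dropout"]


def alter_conv_layout(mod: Module) -> Module:
    # Hypothetical target-dependent layout rewrite.
    return ["conv2d_NCHWc" if op == "conv2d" else op for op in mod]


def make_pipeline(num_targets: int) -> Pass:
    pass_seqs: List[Pass] = [drop_dropout]
    if num_targets == 1:  # like AlterOpLayout: homogeneous execution only
        pass_seqs.append(alter_conv_layout)
    return sequential(pass_seqs)
```

With one target, ["conv2d", "dropout", "relu"] becomes ["conv2d_NCHWc", "relu"]; with two targets the layout rewrite is skipped and the result is ["conv2d", "relu"].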
Computation graph generation
This step is built on the GraphCodegen class; the corresponding code in BuildRelay is:
```cpp
// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
graph_codegen_->Init(nullptr, targets_);
graph_codegen_->Codegen(func);
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
auto lowered_funcs = graph_codegen_->GetIRModule();
```
The GraphCodegen class is implemented in tvm/src/relay/backend/build_module.cc:
```cpp
struct GraphCodegen {
 public:
  GraphCodegen() {
    auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen");
    mod = (*pf)();
  }
  ~GraphCodegen() {}

  void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); }

  void Codegen(const Function& func) { CallFunc("codegen", func); }

  std::string GetJSON() { return CallFunc<std::string>("get_graph_json", nullptr); }

  Array<tvm::runtime::Module> GetExternalModules() {
    return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
  }
  // skip some code ...
 protected:
  tvm::runtime::Module mod;
  // skip some code ...
};
```
Its constructor actually calls relay.build_module._GraphRuntimeCodegen, which is registered in tvm/src/relay/backend/graph_runtime_codegen.cc:
```cpp
runtime::Module CreateGraphCodegenMod() {
  auto ptr = make_object<GraphRuntimeCodegenModule>();
  return runtime::Module(ptr);
}

TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });
```
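GraphRuntimeCodegenModule exposes a codegen method, which calls GraphRuntimeCodegen::Codegen. That method plans memory, turns the function parameters into graph input nodes, visits the function body to build the rest of the graph, serializes the graph to JSON, and returns a LoweredOutput holding the graph JSON, the params, the lowered functions grouped by target, and any external modules.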
```cpp
class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
 public:
  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
    compile_engine_ = CompileEngine::Global();
    targets_ = targets;
  }

  LoweredOutput Codegen(relay::Function func) {
    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
    storage_device_map_ = (*pf)(func);
    // First we convert all the parameters into input nodes.
    for (auto param : func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();
    ret.params = params_;

    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, IRModule());
      }
      auto& mod = ret.lowered_funcs[kv.first];
      mod->Update(kv.second);
      ret.lowered_funcs.Set(kv.first, mod);
    }
    ret.external_mods = compile_engine_->LowerExternalFunctions();
    return ret;
  }
  // skip some code....
};
```
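The shape of what Codegen produces can be illustrated with a toy graph builder. This is a simplified sketch, not the exact graph runtime format: the real JSON has further fields (such as node_row_ptr and a global attrs section with storage and shape information), and graph_codegen plus its arguments are hypothetical. It only shows the two steps named above: parameters become "null" input nodes, and each fused call becomes a "tvm_op" node whose inputs reference earlier nodes by index.

```python
import json


def graph_codegen(params, calls, output):
    """Toy graph codegen.

    params: list of input names.
    calls:  list of (func_name, [input names]) in topological order.
    output: name of the node whose result is the graph output.
    """
    nodes, index = [], {}
    for name in params:  # parameters become input ("null") nodes
        index[name] = len(nodes)
        nodes.append({"op": "null", "name": name, "inputs": []})
    for func_name, args in calls:  # each fused call becomes a tvm_op node
        index[func_name] = len(nodes)
        nodes.append({
            "op": "tvm_op",
            "name": func_name,
            "attrs": {"func_name": func_name,
                      "num_inputs": str(len(args)),
                      "num_outputs": "1"},
            # each input entry is [node index, output index, version]
            "inputs": [[index[a], 0, 0] for a in args],
        })
    graph = {
        "nodes": nodes,
        "arg_nodes": [index[p] for p in params],
        "heads": [[index[output], 0, 0]],
    }
    return json.dumps(graph)
```

For a conv2d followed by relu, the two parameters occupy nodes 0 and 1, the conv2d call is node 2, and the relu call is node 3, which is also the single head.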
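Backend code generation
Once Relay has the lowered functions, the last step hands them to tvm::build for code generation. This jumps to build in tvm/src/driver/driver_api.cc (note that it has several overloads) and eventually reaches the core build shown below. This build supports heterogeneous compilation: the inputs only need to map each target device to its IRModule.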
```cpp
// Build for heterogeneous execution.
runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_host) {
  auto pass_ctx = transform::PassContext::Current();

  std::vector<runtime::Module> device_modules;
  Target target_host_val = target_host;
  // skip some code ...

  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());

  for (const auto& it : inputs) {
    auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx);
    auto& mhost = pair.first;
    auto& mdevice = pair.second;

    mhost_all->Update(mhost);
    if (mdevice->functions.size() != 0) {
      device_modules.push_back(codegen::Build(mdevice, it.first));
    }
  }

  runtime::Module mhost = codegen::Build(mhost_all, target_host_val);
  // Import all modules
  for (const auto& it : device_modules) {
    if (it.operator->()) {
      mhost.Import(it);
    }
  }
  return mhost;
}
```
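The split, build and import structure of this function can be modeled in plain Python. This is a toy sketch: ToyRuntimeModule, split_dev_host and build_hetero are hypothetical, and the "functions whose name starts with kernel run on the device" rule is an invented stand-in for SplitDevHostFuncs. What it does reproduce is the control flow: host parts from every target are merged into one host module, only non-empty device parts are built, and the host module imports each device module.

```python
class ToyRuntimeModule:
    def __init__(self, name, funcs):
        self.name = name
        self.funcs = funcs
        self.imports = []

    def import_module(self, other):
        self.imports.append(other)


def split_dev_host(funcs):
    """Hypothetical split rule: functions named 'kernel*' run on the device."""
    host = [f for f in funcs if not f.startswith("kernel")]
    dev = [f for f in funcs if f.startswith("kernel")]
    return host, dev


def build_hetero(inputs, target_host):
    """inputs maps a target name to that target's list of function names."""
    host_all, device_modules = [], []
    for target, funcs in inputs.items():
        host, dev = split_dev_host(funcs)
        host_all.extend(host)
        if dev:  # only build non-empty device parts
            device_modules.append(ToyRuntimeModule(target, dev))
    mhost = ToyRuntimeModule(target_host, host_all)
    for m in device_modules:  # the host module imports every device module
        mhost.import_module(m)
    return mhost
```

The returned object is the host module, which is why a single module is enough to load and run code that spans several devices.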
The heart of it is codegen::Build, which hands off to the code generation module (tvm/src/target/codegen.cc). It looks up the build function registered for the target's name and calls it:
```cpp
runtime::Module Build(IRModule mod, Target target) {
  // skip some code.....
  std::string build_f_name;
  if (target->kind->name == "micro_dev") {
    build_f_name = "target.build.c";
  } else {
    build_f_name = "target.build." + target->kind->name;
  }
  // the build function.
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  CHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}
```
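The name-based dispatch can be sketched with a dictionary registry. This toy version is not TVM's Registry: register, _REGISTRY, build_llvm, build_c and codegen_build are hypothetical. It mirrors the three rules visible in the C++ above: the function name is "target.build." plus the target kind, micro_dev is routed to the C backend, and a missing entry means that backend was not enabled in the build.

```python
_REGISTRY = {}


def register(name):
    def deco(fn):
        _REGISTRY[name] = fn
        return fn
    return deco


@register("target.build.llvm")
def build_llvm(mod, target):
    return f"llvm-module({mod})"


@register("target.build.c")
def build_c(mod, target):
    return f"c-source-module({mod})"


def codegen_build(mod, target_kind):
    # micro_dev reuses the C backend, as in codegen.cc
    if target_kind == "micro_dev":
        name = "target.build.c"
    else:
        name = "target.build." + target_kind
    bf = _REGISTRY.get(name)
    if bf is None:
        raise RuntimeError(f"{name} is not enabled")
    return bf(mod, target_kind)
```

Adding a backend then only means registering one more "target.build.<kind>" function; the dispatcher itself never changes.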
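Taking LLVM as an example, target.build.llvm is registered in tvm/src/target/llvm/llvm_module.cc. It creates an LLVMModuleNode and calls its Init method, defined in the same file, which hands off to the CodeGenLLVM class in tvm/src/target/llvm/codegen_llvm.cc to generate the code.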
![25ee80fba17525d4ee50b640236578d6.png](https://img-blog.csdnimg.cn/img_convert/25ee80fba17525d4ee50b640236578d6.png)
This completes the relay.build process: the result is code that can run on the backend.
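The above is only a rough tour of the Relay compilation flow. Some of the code was not examined in detail and the write-up is fairly rough; I will come back and go through it more carefully when I have time.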