为了与TVM进行对比,笔者决定同时看一下Tengine是如何做的。首先还是从图优化入手。
Tengine的整体架构
根据架构以及代码发现,Tengine将模型转换分离出来了,叫做Tengine-Convert-Tools
Tengine-Convert-Tools Github
有一说一,看的出来,Tengine的开发团队应该不是特别多,时间紧任务重,几乎所有的函数都没有注释(同时也没文档),虽然靠函数名也能猜出来,但是对于刚开始看的同学来说确实不太友好,对比来说,TVM或者任何其他的框架都做的要好很多。当然,baipiao就别有这么多想法了额。
Tengine Convert Tools
这个工具的功能也很简单,可以理解成onnx吧,作为一个中间键,后续Tengine的前端只需要解析tmfile就ok了。
我们还是看一下它做的图优化吧,
Tengine 图优化
Convert Tools的目的就是将原来框架的graph转换成Tengine的stactic graph。
基于此,Convert Tools中的图优化可以分为两个部分,
1. Serialize 优化
这一步是在由model创建graph的时候完成的。
graph = create_graph(nullptr, file_format.c_str(), model_file.c_str());
在create_graph函数中会调用:
vload_file_model(exec_context, model_name.c_str(), model_format, fname, argp)
其中,model_format代表转换的哪个框架,例如tf,onnx;fname代表转换的模型的名称,例如xxx.pb;model_name代表转换后的static graph的名字。
一直往下走的时候,一直到real load函数,才是根据不同的model_format来进行调用具体的load函数的。
static int real_vload_model(context_t exec_context, const char* model_name, const char* model_format, const void* addr,
int mem_size, va_list argp)
{
SerializerPtr serializer;
if (!SerializerManager::SafeGet(model_format, serializer))
{
/* try to load from plugin */
std::string plugin_fname = std::string("lib") + model_format + "-serializer.so";
std::string plugin_init_func = std::string(model_format) + "_plugin_init";
if (load_tengine_plugin(model_format, plugin_fname.c_str(), plugin_init_func.c_str()) < 0)
{
LOG_ERROR() << "Get serializer failed, unknown model format: " << model_format << "\n";
set_tengine_errno(ENOENT);
return -1;
}
SerializerManager::SafeGet(model_format, serializer);
}
StaticGraph* static_graph = CreateStaticGraph(model_name);
std::cout << "[step] [Test 2.2]" << std::endl;
static_graph->exec_context = exec_context;
int saved_file_number = serializer->GetFileNum();
if (mem_size == 0) // file mode
{
std::vector<std::string> file_list;
file_list.push_back(( const char* )addr);
for (int i = 1; i < saved_file_number; i++)
{
const char* file = va_arg(argp, const char*);
file_list.emplace_back(file);
}
if (!serializer->LoadModel(file_list, static_graph) || !CheckGraphIntegraity(static_graph))
{
delete static_graph;
return -1;
}
}
else
{
std::vector<const void*> addr_list;
std::vector<int> size_list;
addr_list.push_back(addr);
size_list.push_back(mem_size);
for (int i = 1; i < saved_file_number; i++)
{
addr = va_arg(argp, const void*);
mem_size = va_arg(argp, int);
addr_list.push_back(addr);
size_list.push_back(mem_size);
}
if (!serializer->LoadModel(addr_list, size_list, static_graph) || !CheckGraphIntegraity(static_graph))
{
delete static_graph;
return -1;
}
}
va_end(argp);
if (!StaticGraphManager::Add(std::string(model_name), StaticGraphPtr(static_graph)))
{
XLOG_ERROR() << "replicated model name detected: " << model_name << " should not happen\n";
set_tengine_errno(EBADSLT);
return -1;
}
return 0;
}
首先就是根据不同的model_format生成不同的继承于Serializer的类对象,这个是通过simpleobjectmanager作为维护的一个map来实现的,“tensorflow”就对应于tf_serializer类的对象。
同时为了线程安全,在每次去拿serialize的时候要先上锁,拿完之后再解锁,避免其他线程在该线程拿的时候对map进行了操作,使得拿到的数据不对。
拿完之后就会进行对应serialize的load model操作
serializer->LoadModel(addr_list, size_list, static_graph) || !CheckGraphIntegraity(static_graph)
我们去看下tf的函数如何做的
bool TFSerializer::LoadModel(const std::vector<std::string>& file_list, StaticGraph* graph)
{
tensorflow::GraphDef tf_net;
if ( //! LoadTextFile(file_list[0].c_str(), tf_net) &&
!LoadBinaryFile(file_list[0].c_str(), tf_net))
return false;
SetGraphSource(graph, file_list[0]);
SetGraphSourceFormat(graph, "tensorflow");
SetGraphConstTensorFile(graph, file_list[0]);
SetGraphLayout(graph, TENGINE_LAYOUT_NHWC);
SetModelLayout(graph, TENGINE_LAYOUT_NHWC);
SetModelFormat(graph, MODEL_FORMAT_TENSORFLOW);
return LoadGraph(tf_net, graph);
}
开始就是先将数据load到tf_net中,之后进行初始化的一些工作,然后将model load到StaticGraph中,
bool TFSerializer::LoadGraph(tensorflow::GraphDef& tf_net, StaticGraph* graph)
{
TFGraph tf_graph;
// step 1: construct whole graph
if (!ConstructGraph(tf_net, tf_graph))
return false;
if (!OptimizeRNN(tf_net, tf_graph))
return false;
// step 2: scanning and fusing nodes
if (!OptimizeGraph(tf_graph))
return false;
// step 3: create static graph
if (!GenerateStaticGraph(tf_graph, graph))
return false;
return true;
}
我们看到在step2的时候就会进行针对tf的一些图优化的操作,例如移除squeeze和identity;融合FIFOQueueV2等等。
对于其他的框架模型,例如onnx就没有图优化,
bool OnnxSerializer::LoadGraph(onnx::ModelProto& model, StaticGraph* graph)
{
const onnx::GraphProto& onnx_graph = model.graph();
SetGraphIdentity(graph, model.domain(), onnx_graph.name(), std::to_string(( int )model.model_version()));
LoadConstTensor(graph, onnx_graph);
CreateInputNode(graph, onnx_graph);
int i;
std::vector<std::string> no_supported_op;
for (i = 0; i < onnx_graph.node_size(); i++)
{
const onnx::NodeProto& onnx_node = onnx_graph.node(i);
const std::string& onnx_op_name = onnx_node.op_type();
if (!FindOpLoadMethod(onnx_op_name))
{
auto it = find(no_supported_op.begin(), no_supported_op.end(), onnx_op_name);
if (it == no_supported_op.end())
{
if (onnx_op_name == "Constant")
continue;
no_supported_op.push_back(onnx_op_name);
}
// LOG_ERROR() << "cannot find load function for operator: " << onnx_op_name << "\n";
// continue;
}
}
if (no_supported_op.size())
{
LOG_ERROR() << "These " << no_supported_op.size() << "op are not supported\n";
LOG_ERROR() << "{";
for (int j = 0; j < ( int )no_supported_op.size(); j++)
{
LOG_ERROR() << no_supported_op[j] << ",";
}
LOG_ERROR() << "}\n";
return false;
}
for(int i = 0; i < onnx_graph.node_size(); i++){
const onnx::NodeProto& onnx_node = onnx_graph.node(i);
const std::string& onnx_op_name = onnx_node.op_type();
if(onnx_op_name == "null" || onnx_op_name == "_zeros" || onnx_op_name == "constant")
continue;
std::vector<std::string>::iterator iter=std::find(support_op.begin(), support_op.end(), onnx_op_name);
if(iter==support_op.end()){
std::vector<std::string>::iterator uniter=std::find(unsupport_op.begin(), unsupport_op.end(), onnx_op_name);
if(uniter==unsupport_op.end()){
unsupport_op.push_back(onnx_op_name);
} else {
continue;
}
} else {
continue;
}
}
if(unsupport_op.size() != 0){
printf("These ops are not in onnx serializer: \n");
for(int i = 0; i < (int)unsupport_op.size(); i++){
printf("[ %s ]\n", unsupport_op[i].c_str());
}
printf("\n");
printf("You may need use onnx simplifier first\n");
return false;
}
for (i = 0; i < onnx_graph.node_size(); i++)
{
const onnx::NodeProto& onnx_node = onnx_graph.node(i);
const std::string& onnx_op_name = onnx_node.op_type();
if (onnx_op_name == "Constant")
continue;
StaticNode* node = CreateStaticNode(graph, onnx_node.output(0));
if (!LoadNode(graph, node, onnx_node))
break;
op_load_t op_func = any_cast<op_load_t>(GetOpLoadMethod(onnx_op_name));
if (!op_func(graph, node, onnx_node))
break;
}
if (i < onnx_graph.node_size())
return false;
return true;
}
可以看到,整个loadgraph函数就全部都是一些常规的转换操作,并没有图优化的操作,可能默认onnx转换之前已经进行过图优化了吧,毕竟有onnx-simplifier这种工具。
最终生成的graph结构体如下:
当将模型存到staticgraph之后,这里还不算完,作者这里估计只是将staticgraph当做一个单纯的保存graph信息的媒介。这样也便于manager管理,需要保存的东西不是那么多。当所有线程把全部的model都load完之后,将static graph转到真正的graph executor需要的graph形式(笔者理解这是由前端到后端的数据传递)
graph = Graph::CreateFromStatic(graph_name, static_graph);
if (static_graph != nullptr)
{
/* set dev handle */
exec_attr_.dev_handle = static_graph->dev_handle;
if (exec_context != static_graph->exec_context)
{
XLOG_INFO() << "create runtime graph from different context model\n";
}
}
然后给到GraphExecutor的graph_,并将GraphExecutor*作为最终的graph往后传(当然,数据什么的都是存在graphtask中的)。
2. Device优化
这部分是对前一部分生成的GraphExecutor进行prerun的时候进行的。
if (prerun_graph(graph) < 0)
{
std::cout << "prerun failed\n";
return -1;
}
后面的数据传输比较繁琐,看上去就是将所有的东西,比如graph,dev_engine等都放到graph_task中去,然后为了多线程不同的平台能一块跑,将dev的信息合并原来的graph task信息生成subgraph_task放到list中。
之后每一个subtask生成一个scheduler去preruntask。
后面具体跑到哪个dev上,就要根据不同的executor(都是继承自GenericDevExecutor)去进行图优化。
以Cpu为例,
bool CPUExecutor::DevOptimizeGraph(void* graph_handle)
{
return backend_dev_->OptimizeGraph(graph_handle);
}
bool OptimizeGraph(void* graph_handle)
{
return driver_->OptimizeGraph(this, graph_handle);
}
因为backend_dev_就是CPUDevice,所以就用CPU的drive去跑,
bool CPUDriver::OptimizeGraph(Device* dev, void* graph_handle)
{
DevContext* context = reinterpret_cast<DevContext*>(graph_handle);
return OptimizeGraph(dev, graph_handle, context->sub_graph);
}
bool CPUDriver::OptimizeGraph(Device* dev, void* graph_handle, Subgraph* graph)
{
DevContext* context = reinterpret_cast<DevContext*>(graph_handle);
CPUDevice* cpu_info = context->dev;
context->sub_graph = graph;
return cpu_info->RealOptimizeGraph(context, graph);
}
后面具体的device是CPUDdevice,调用了他的RealOptimizeGraph函数,去看下做了哪些优化:
bool RealOptimizeGraph(DevContext* context, Subgraph* graph)
{
context->optimized_graph = graph;
return backend_runner_.OptimizeGraph(context->optimized_graph);
}
bool CPURunner::OptimizeGraph(Subgraph* optimized_graph)
{
std::cout << " [step] [Device Optimize]" << std::endl;
GraphOptimizerManager::RunOpt("BNScale", optimized_graph);
GraphOptimizerManager::RunOpt("FcBn", optimized_graph);
GraphOptimizerManager::RunOpt("ConvBN", optimized_graph);
GraphOptimizerManager::RunOpt("ConvReLu", optimized_graph);
GraphOptimizerManager::RunOpt("ConvReLu6", optimized_graph);
return true;
}
这里就是device在做的优化,可以看到都是非常普通的图优化,没啥特别的。
总结
总的来说,Tengine做的图优化还是比较常规比较少的,但是为了能够多线程多平台去跑,感觉还是非常不错的框架,就是代码注释多点就好了,哈哈