Tengine Convert Tools代码走读-图优化篇

为了与TVM进行对比,笔者决定同时看一下Tengine是如何做的。首先还是从图优化入手。

Tengine的整体架构

Tengine整体代码架构
根据架构以及代码发现,Tengine将模型转换分离出来了,叫做Tengine-Convert-Tools
Tengine-Convert-Tools Github
有一说一,看的出来,Tengine的开发团队应该不是特别多,时间紧任务重,几乎所有的函数都没有注释(同时也没文档),虽然靠函数名也能猜出来,但是对于刚开始看的同学来说确实不太友好,对比来说,TVM或者任何其他的框架都做的要好很多。当然,baipiao就别有这么多想法了额。

Tengine Convert Tools

这个工具的功能也很简单,可以理解成onnx吧,作为一个中间键,后续Tengine的前端只需要解析tmfile就ok了。
我们还是看一下它做的图优化吧,

Tengine 图优化

Convert Tools的目的就是将原来框架的graph转换成Tengine的stactic graph。
基于此,Convert Tools中的图优化可以分为两个部分,

1. Serialize 优化

这一步是在由model创建graph的时候完成的。

graph = create_graph(nullptr, file_format.c_str(), model_file.c_str());

在create_graph函数中会调用:

vload_file_model(exec_context, model_name.c_str(), model_format, fname, argp)

其中,model_format代表转换的哪个框架,例如tf,onnx;fname代表转换的模型的名称,例如xxx.pb;model_name代表转换后的static graph的名字。
一直往下走的时候,一直到real load函数,才是根据不同的model_format来进行调用具体的load函数的。

static int real_vload_model(context_t exec_context, const char* model_name, const char* model_format, const void* addr,
                            int mem_size, va_list argp)
{
    SerializerPtr serializer;

    if (!SerializerManager::SafeGet(model_format, serializer))
    {
        /* try to load from plugin */
        std::string plugin_fname = std::string("lib") + model_format + "-serializer.so";
        std::string plugin_init_func = std::string(model_format) + "_plugin_init";
        if (load_tengine_plugin(model_format, plugin_fname.c_str(), plugin_init_func.c_str()) < 0)
        {
            LOG_ERROR() << "Get serializer failed, unknown model format: " << model_format << "\n";
            set_tengine_errno(ENOENT);
            return -1;
        }

        SerializerManager::SafeGet(model_format, serializer);
    }

    StaticGraph* static_graph = CreateStaticGraph(model_name);
    std::cout << "[step] [Test 2.2]" << std::endl;
    static_graph->exec_context = exec_context;

    int saved_file_number = serializer->GetFileNum();

    if (mem_size == 0)    // file mode
    {
        std::vector<std::string> file_list;
        file_list.push_back(( const char* )addr);

        for (int i = 1; i < saved_file_number; i++)
        {
            const char* file = va_arg(argp, const char*);
            file_list.emplace_back(file);
        }

        if (!serializer->LoadModel(file_list, static_graph) || !CheckGraphIntegraity(static_graph))
        {
            delete static_graph;
            return -1;
        }
    }
    else
    {
        std::vector<const void*> addr_list;
        std::vector<int> size_list;

        addr_list.push_back(addr);
        size_list.push_back(mem_size);

        for (int i = 1; i < saved_file_number; i++)
        {
            addr = va_arg(argp, const void*);
            mem_size = va_arg(argp, int);

            addr_list.push_back(addr);
            size_list.push_back(mem_size);
        }

        if (!serializer->LoadModel(addr_list, size_list, static_graph) || !CheckGraphIntegraity(static_graph))
        {
            delete static_graph;
            return -1;
        }
    }
    va_end(argp);

    if (!StaticGraphManager::Add(std::string(model_name), StaticGraphPtr(static_graph)))
    {
        XLOG_ERROR() << "replicated model name detected: " << model_name << " should not happen\n";
        set_tengine_errno(EBADSLT);
        return -1;
    }
    return 0;
}

首先就是根据不同的model_format生成不同的继承于Serializer的类对象,这个是通过simpleobjectmanager作为维护的一个map来实现的,“tensorflow”就对应于tf_serializer类的对象。
同时为了线程安全,在每次去拿serialize的时候要先上锁,拿完之后再解锁,避免其他线程在该线程拿的时候对map进行了操作,使得拿到的数据不对。
拿完之后就会进行对应serialize的load model操作

serializer->LoadModel(addr_list, size_list, static_graph) || !CheckGraphIntegraity(static_graph)

我们去看下tf的函数如何做的

bool TFSerializer::LoadModel(const std::vector<std::string>& file_list, StaticGraph* graph)
{
    tensorflow::GraphDef tf_net;

    if (    //! LoadTextFile(file_list[0].c_str(), tf_net) &&
        !LoadBinaryFile(file_list[0].c_str(), tf_net))
        return false;

    SetGraphSource(graph, file_list[0]);
    SetGraphSourceFormat(graph, "tensorflow");
    SetGraphConstTensorFile(graph, file_list[0]);
    SetGraphLayout(graph, TENGINE_LAYOUT_NHWC);
    SetModelLayout(graph, TENGINE_LAYOUT_NHWC);
    SetModelFormat(graph, MODEL_FORMAT_TENSORFLOW);

    return LoadGraph(tf_net, graph);
}

开始就是先将数据load到tf_net中,之后进行初始化的一些工作,然后将model load到StaticGraph中,

bool TFSerializer::LoadGraph(tensorflow::GraphDef& tf_net, StaticGraph* graph)
{
    TFGraph tf_graph;

    // step 1: construct whole graph

    if (!ConstructGraph(tf_net, tf_graph))
        return false;

    if (!OptimizeRNN(tf_net, tf_graph))
        return false;

    // step 2: scanning and fusing nodes

    if (!OptimizeGraph(tf_graph))
        return false;

    // step 3: create static graph
    if (!GenerateStaticGraph(tf_graph, graph))
        return false;

    return true;
}

我们看到在step2的时候就会进行针对tf的一些图优化的操作,例如移除squeeze和identity;融合FIFOQueueV2等等。
对于其他的框架模型,例如onnx就没有图优化,

bool OnnxSerializer::LoadGraph(onnx::ModelProto& model, StaticGraph* graph)
{
    const onnx::GraphProto& onnx_graph = model.graph();

    SetGraphIdentity(graph, model.domain(), onnx_graph.name(), std::to_string(( int )model.model_version()));

    LoadConstTensor(graph, onnx_graph);
    CreateInputNode(graph, onnx_graph);

    int i;
    std::vector<std::string> no_supported_op;
    for (i = 0; i < onnx_graph.node_size(); i++)
    {
        const onnx::NodeProto& onnx_node = onnx_graph.node(i);
        const std::string& onnx_op_name = onnx_node.op_type();

        if (!FindOpLoadMethod(onnx_op_name))
        {
            auto it = find(no_supported_op.begin(), no_supported_op.end(), onnx_op_name);
            if (it == no_supported_op.end())
            {
                if (onnx_op_name == "Constant")
                    continue;
                no_supported_op.push_back(onnx_op_name);
            }
            //       LOG_ERROR() << "cannot find load function for operator: " << onnx_op_name << "\n";
            //       continue;
        }
    }
    if (no_supported_op.size())
    {
        LOG_ERROR() << "These " << no_supported_op.size() << "op are not supported\n";
        LOG_ERROR() << "{";
        for (int j = 0; j < ( int )no_supported_op.size(); j++)
        {
            LOG_ERROR() << no_supported_op[j] << ",";
        }
        LOG_ERROR() << "}\n";

        return false;
    }
    for(int i = 0; i < onnx_graph.node_size(); i++){
        const onnx::NodeProto& onnx_node = onnx_graph.node(i);
        const std::string& onnx_op_name = onnx_node.op_type();
        if(onnx_op_name == "null" || onnx_op_name == "_zeros" || onnx_op_name == "constant")
            continue; 

        std::vector<std::string>::iterator iter=std::find(support_op.begin(), support_op.end(), onnx_op_name);
        if(iter==support_op.end()){
            std::vector<std::string>::iterator uniter=std::find(unsupport_op.begin(), unsupport_op.end(), onnx_op_name);
            if(uniter==unsupport_op.end()){
                unsupport_op.push_back(onnx_op_name);
            } else {
                continue;
            }
        } else {
            continue;
        }
    }
    if(unsupport_op.size() != 0){
        printf("These ops are not in onnx serializer: \n");
        for(int i = 0; i < (int)unsupport_op.size(); i++){
            printf("[ %s ]\n", unsupport_op[i].c_str());
        }
        printf("\n");
        printf("You may need use onnx simplifier first\n");
        return false;
    }
    for (i = 0; i < onnx_graph.node_size(); i++)
    {
        const onnx::NodeProto& onnx_node = onnx_graph.node(i);
        const std::string& onnx_op_name = onnx_node.op_type();

        if (onnx_op_name == "Constant")
            continue;
        StaticNode* node = CreateStaticNode(graph, onnx_node.output(0));

        if (!LoadNode(graph, node, onnx_node))
            break;

        op_load_t op_func = any_cast<op_load_t>(GetOpLoadMethod(onnx_op_name));

        if (!op_func(graph, node, onnx_node))
            break;
    }

    if (i < onnx_graph.node_size())
        return false;

    return true;
}

可以看到,整个loadgraph函数就全部都是一些常规的转换操作,并没有图优化的操作,可能默认onnx转换之前已经进行过图优化了吧,毕竟有onnx-simplifier这种工具。
最终生成的graph结构体如下:
static graph定义
当将模型存到staticgraph之后,这里还不算完,作者这里估计只是将staticgraph当做一个单纯的保存graph信息的媒介。这样也便于manager管理,需要保存的东西不是那么多。当所有线程把全部的model都load完之后,将static graph转到真正的graph executor需要的graph形式(笔者理解这是由前端到后端的数据传递)

graph = Graph::CreateFromStatic(graph_name, static_graph);

if (static_graph != nullptr)
{
    /* set dev handle */
    exec_attr_.dev_handle = static_graph->dev_handle;

    if (exec_context != static_graph->exec_context)
    {
        XLOG_INFO() << "create runtime graph from different context model\n";
    }
}

然后给到GraphExecutor的graph_,并将GraphExecutor*作为最终的graph往后传(当然,数据什么的都是存在graphtask中的)。

2. Device优化

这部分是对前一部分生成的GraphExecutor进行prerun的时候进行的。

if (prerun_graph(graph) < 0)
{
    std::cout << "prerun failed\n";
    return -1;
}

后面的数据传输比较繁琐,看上去就是将所有的东西,比如graph,dev_engine等都放到graph_task中去,然后为了多线程不同的平台能一块跑,将dev的信息合并原来的graph task信息生成subgraph_task放到list中。
之后每一个subtask生成一个scheduler去preruntask。
后面具体跑到哪个dev上,就要根据不同的executor(都是继承自GenericDevExecutor)去进行图优化。
以Cpu为例,

bool CPUExecutor::DevOptimizeGraph(void* graph_handle)
{
    return backend_dev_->OptimizeGraph(graph_handle);
}
bool OptimizeGraph(void* graph_handle)
{
    return driver_->OptimizeGraph(this, graph_handle);
}

因为backend_dev_就是CPUDevice,所以就用CPU的drive去跑,

bool CPUDriver::OptimizeGraph(Device* dev, void* graph_handle)
{
    DevContext* context = reinterpret_cast<DevContext*>(graph_handle);

    return OptimizeGraph(dev, graph_handle, context->sub_graph);
}
bool CPUDriver::OptimizeGraph(Device* dev, void* graph_handle, Subgraph* graph)
{
    DevContext* context = reinterpret_cast<DevContext*>(graph_handle);
    CPUDevice* cpu_info = context->dev;

    context->sub_graph = graph;

    return cpu_info->RealOptimizeGraph(context, graph);
}

后面具体的device是CPUDdevice,调用了他的RealOptimizeGraph函数,去看下做了哪些优化:

bool RealOptimizeGraph(DevContext* context, Subgraph* graph)
{
    context->optimized_graph = graph;
    return backend_runner_.OptimizeGraph(context->optimized_graph);
}
bool CPURunner::OptimizeGraph(Subgraph* optimized_graph)
{
    std::cout << " [step] [Device Optimize]" << std::endl;
    GraphOptimizerManager::RunOpt("BNScale", optimized_graph);
    GraphOptimizerManager::RunOpt("FcBn", optimized_graph);
    GraphOptimizerManager::RunOpt("ConvBN", optimized_graph);
    GraphOptimizerManager::RunOpt("ConvReLu", optimized_graph);
    GraphOptimizerManager::RunOpt("ConvReLu6", optimized_graph);

    return true;
}

这里就是device在做的优化,可以看到都是非常普通的图优化,没啥特别的。

总结

总的来说,Tengine做的图优化还是比较常规比较少的,但是为了能够多线程多平台去跑,感觉还是非常不错的框架,就是代码注释多点就好了,哈哈

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值