使用protobuf解析Onnx文件

使用OpenCV加载Onnx推理的时候,无法获取到Onnx的网络输入大小,并且对推理速度要求不高不需要使用TensorRT的时候,如何才能得知Onnx的一些必要的信息,OpenCV没有提供接口,只能自己从Onnx文件中解析了。

Onnx文件是使用protobuf序列化后的二进制数据,想要读取里面的信息需要使用protobuf将其反序列化为对象才行。

第一步:

编译 protobuf,protocolbuffers/protobuf: Protocol Buffers - Google's data interchange format (github.com)

使用cmake生成vs工程直接编译即可。

第二步:

使用protoc命令,将编写好的proto文件生成C++类定义文件。

.\protoc.exe --cpp_out=. ./onnx.proto

得到 onnx.pb.h 和 onnx.pb.cpp 两个文件,将其加入到工程中。

Onnx.proto 文件可以从 onnx/onnx: Open standard for machine learning interoperability (github.com) 找到。

使用编译好的库文件编译的过程中,如果出现 无法解析的外部符号 "class google::protobuf::internal::ExplicitlyConstructed fixed_address_empty_string" 这个错误,添加预处理宏定义 PROTOBUF_USE_DLLS 即可。

关键代码:

加载反序列化onnx文件

    ifstream fin("model.onnx", std::ios::in | std::ios::binary);
    onnx::ModelProto onnx_model;
    onnx_model.ParseFromIstream(&fin);

打印一些信息   

    std::cout << "ir_version: " << onnx_model.ir_version() << std::endl;
    std::cout << "opset_import_size: " << onnx_model.opset_import_size() << std::endl;
    std::cout << "OperatorSetIdProto domain: " << onnx_model.opset_import(0).domain() << std::endl;
    std::cout << "OperatorSetIdProto version: " << onnx_model.opset_import(0).version() << std::endl;
    std::cout << "producer_name: " << onnx_model.producer_name() << std::endl;
    std::cout << "producer_version: " << onnx_model.producer_version() << std::endl;
    std::cout << "domain: " << onnx_model.domain() << std::endl;
    std::cout << "model_version: " << onnx_model.model_version() << std::endl;
    std::cout << "doc_string: " << onnx_model.doc_string() << std::endl;

输入节点的个数

 

onnx_model.graph().input_size()

输入节点的名称

 

 onnx_model.graph().input(0).name()

输入输出节点的数据类型

onnx_model.graph().input(0).type().tensor_type().elem_type()

返回的是int类型,与实际数据类型对应关系见:onnx/onnx.proto at main · onnx/onnx (github.com) 文件中的 DataType 枚举类型。

输入节点的输入维度

int dim_size = onnx_model.graph().input(0).type().tensor_type().shape().dim_size();
for (int i = 0; i < dim_size; i++)
{
    onnx_model.graph().input(0).type().tensor_type().shape().dim().Get(i).dim_value();
}

封装类 OnnxInfo

OnnxInfo.h


#ifdef __cplusplus
extern "C" {
#endif

    DLL_API uint64_t onnx_load(const char* path);

    DLL_API void onnx_close(uint64_t ptr_addr);

    DLL_API int onnx_get_input_count(uint64_t ptr_addr);

    DLL_API int onnx_get_output_count(uint64_t ptr_addr);

    DLL_API const char* onnx_get_input_name(uint64_t ptr_addr, int input_index);

    DLL_API const char* onnx_get_output_name(uint64_t ptr_addr, int output_index);

    DLL_API const char* onnx_get_input_data_type(uint64_t ptr_addr, int input_index);

    DLL_API const char* onnx_get_output_data_type(uint64_t ptr_addr, int output_index);

    DLL_API int onnx_get_input_dims(uint64_t ptr_addr, int input_index, int* dims);

    DLL_API int onnx_get_output_dims(uint64_t ptr_addr, int output_index, int* dims);

#ifdef __cplusplus
}
#endif

OnnxInfo.cpp

#define _DLL_INTERNEL_

#include "OnnxInfo.h"
#include "onnx.pb.h"

#include <fstream>
#include <memory>
#include <vector>



using namespace std;

#ifdef __cplusplus
extern "C" {
#endif

    namespace {

        char g_string[256];

        vector<shared_ptr<onnx::ModelProto>> g_models;

        void g_release_ptr(uint64_t ptr_addr)
        {
            auto iter = g_models.begin();
            while (iter != g_models.end())
            {
                if (ptr_addr == (uint64_t)iter->get())
                {
                    g_models.erase(iter);
                    break;
                }
                iter++;
            }
        }

        shared_ptr<onnx::ModelProto> g_get_ptr(uint64_t ptr_addr)
        {
            for (auto& ptr : g_models)
            {
                if (ptr_addr == (uint64_t)ptr.get())
                {
                    return ptr;
                }
            }
            return nullptr;
        }

        const char* g_get_data_type_name_by_id(int data_type_id)
        {
            auto dt = (onnx::TensorProto_DataType)data_type_id;
            switch (dt)
            {
            case onnx::TensorProto_DataType_UNDEFINED:
                return "UNDEFINED";
            case onnx::TensorProto_DataType_FLOAT:
                return "FLOAT";
            case onnx::TensorProto_DataType_UINT8:
                return "UINT8";
            case onnx::TensorProto_DataType_INT8:
                return "FLOAT";
            case onnx::TensorProto_DataType_UINT16:
                return "UINT16";
            case onnx::TensorProto_DataType_INT16:
                return "INT16";
            case onnx::TensorProto_DataType_INT32:
                return "INT32";
            case onnx::TensorProto_DataType_INT64:
                return "INT64";
            case onnx::TensorProto_DataType_STRING:
                return "STRING";
            case onnx::TensorProto_DataType_BOOL:
                return "BOOL";
            case onnx::TensorProto_DataType_FLOAT16:
                return "FLOAT16";
            case onnx::TensorProto_DataType_DOUBLE:
                return "DOUBLE";
            case onnx::TensorProto_DataType_UINT32:
                return "UINT32";
            case onnx::TensorProto_DataType_UINT64:
                return "UINT64";
            case onnx::TensorProto_DataType_COMPLEX64:
                return "COMPLEX64";
            case onnx::TensorProto_DataType_COMPLEX128:
                return "COMPLEX128";
            case onnx::TensorProto_DataType_BFLOAT16:
                return "BFLOAT16";
            default:
                return "";
            }
        }
    }
    
    /// <summary> 加载onnx </summary>
    uint64_t onnx_load(const char* path)
    {
        ifstream fin(path, std::ios::in | std::ios::binary);

        auto onnx_model_ptr = make_shared<onnx::ModelProto>();

        bool bret = onnx_model_ptr->ParseFromIstream(&fin);
        fin.close();

        if (!bret) return 0;

        uint64_t ptr = (uint64_t)onnx_model_ptr.get();
        g_models.push_back(move(onnx_model_ptr));
        return ptr;
    }

    /// <summary> 关闭onnx </summary>
    void onnx_close(uint64_t ptr_addr)
    {
        g_release_ptr(ptr_addr);
    }

    /// <summary> 获取输入节点的个数 </summary>
    int onnx_get_input_count(uint64_t ptr_addr)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return -1;

        if (!ptr->has_graph()) return -2;

        return ptr->graph().input_size();
    }

    /// <summary> 获取输出节点的个数 </summary>
    int onnx_get_output_count(uint64_t ptr_addr)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return -1;

        if (!ptr->has_graph()) return -2;

        return ptr->graph().output_size();
    }


    /// <summary> 获取输入节点的名称 </summary>
    const char* onnx_get_input_name(uint64_t ptr_addr, int input_index)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return "";

        if (!ptr->has_graph()) return "";

        int input_size = ptr->graph().input_size();
        if (input_index >= input_size || input_index < 0) return "";

        auto input = ptr->graph().input(input_index);

        sprintf_s(g_string, sizeof(g_string), input.name().c_str());

        return g_string;
    }

    /// <summary> 获取输出节点的名称 </summary>
    const char* onnx_get_output_name(uint64_t ptr_addr, int output_index)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return "";

        if (!ptr->has_graph()) return "";

        int output_size = ptr->graph().output_size();
        if (output_index >= output_size || output_index < 0) return "";

        auto output = ptr->graph().output(output_index);

        sprintf_s(g_string, sizeof(g_string), output.name().c_str());

        return g_string;
    }

    /// <summary> 获取输入节点的数据类型 </summary>
    const char* onnx_get_input_data_type(uint64_t ptr_addr, int input_index)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return "";

        if (!ptr->has_graph()) return "";

        int input_size = ptr->graph().input_size();
        if (input_index >= input_size || input_index < 0) return "";

        auto input = ptr->graph().input(input_index);

        auto type_id = input.type().tensor_type().elem_type();

        return g_get_data_type_name_by_id(type_id);
    }

    /// <summary> 获取输出节点的数据类型 </summary>
    const char* onnx_get_output_data_type(uint64_t ptr_addr, int output_index)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return "";

        if (!ptr->has_graph()) return "";

        int output_size = ptr->graph().output_size();
        if (output_index >= output_size || output_index < 0) return "";

        auto output = ptr->graph().output(output_index);

        auto type_id = output.type().tensor_type().elem_type();

        return g_get_data_type_name_by_id(type_id);
    }

    /// <summary> 获取输入节点的维数 </summary>
    int onnx_get_input_dims(uint64_t ptr_addr, int input_index, int* dims)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return -1;

        if (!ptr->has_graph()) return -2;

        int input_size = ptr->graph().input_size();
        if (input_index >= input_size || input_index < 0) return -3;

        auto input = ptr->graph().input(input_index);

        int dim_size = input.type().tensor_type().shape().dim_size();
        

        if (dims) for (int i = 0; i < dim_size; i++)
        {
            dims[i] = input.type().tensor_type().shape().dim().Get(i).dim_value();
        }

        return dim_size;
    }

    /// <summary> 获取输出节点的维数 </summary>
    int onnx_get_output_dims(uint64_t ptr_addr, int output_index, int* dims)
    {
        auto ptr = g_get_ptr(ptr_addr);
        if (!ptr.get()) return -1;

        if (!ptr->has_graph()) return -2;

        int output_size = ptr->graph().output_size();
        if (output_index >= output_size || output_index < 0) return -3;

        auto output = ptr->graph().output(output_index);

        int dim_size = output.type().tensor_type().shape().dim_size();

        if (dims) for (int i = 0; i < dim_size; i++)
        {
            dims[i] = output.type().tensor_type().shape().dim().Get(i).dim_value();
        }

        return dim_size;
    }



#ifdef __cplusplus
}
#endif

调用方法

int main()
{
    const char* path = "myunet.onnx";

    auto adr = onnx_load(path);

    int ic = onnx_get_input_count(adr);
    int oc = onnx_get_output_count(adr);

    cout << onnx_get_input_name(adr, 0) << endl;
    cout << onnx_get_output_name(adr, 0) << endl;

    cout << onnx_get_input_data_type(adr, 0) << endl;
    cout << onnx_get_output_data_type(adr, 0) << endl;

    int dim_size;
    int dims[8];

    cout << "input dim size: " << onnx_get_input_dims(adr, 0, 0) << endl;
    cout << "input dim size2: " << (dim_size = onnx_get_input_dims(adr, 0, dims)) << endl;
    for (int i = 0; i < dim_size; i++)
    {
        cout << dims[i] << " ";
    }
    cout << endl;

    cout << "output dim size: " << onnx_get_output_dims(adr, 0, 0) << endl;
    cout << "output dim size2: " << (dim_size = onnx_get_output_dims(adr, 0, dims)) << endl;
    for (int i = 0; i < dim_size; i++)
    {
        cout << dims[i] << " ";
    }
    cout << endl;

    onnx_close(adr);

    cin.ignore();

    return 0;
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ango_Cango

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值