caffe模型文件解析_MNN模型转化的主流程

本文详细介绍了如何将Caffe模型转换为MNN格式,包括模型转换的三个主要步骤:参数解析、模型转换和写模型过程。在模型转换过程中,重点讨论了ONNX模型的输入输出格式和转化过程,以及在转化过程中涉及到的对象、算子转化和优化策略。文章还提及了MNN模型优化的前后优化阶段,以及模型文件的写入。
摘要由CSDN通过智能技术生成

06f6ab1deb24ebd75cf712f47d793f35.png

模型转换MNNConvert

从MNN官方文档看到, 模型转化工具是是一个MNNConverter的bin工具, 支持把tensorflow, onnx,caffe模型转化给MNN定义的模型.

直接进入tools/converter目录, 入口函数在MNNConverter.cpp的main函数. 主函数分为3个大步骤:

  • 参数解析
  • 读取模型、解析、转换
  • 生成新的MNN格式模型
 int main(int argc, char *argv[]) {
    
     modelConfig modelPath;
     Cli::initializeMNNConvertArgs(modelPath, argc, argv); // 解析命令行参数,存到modelConfig类里
     // 根据输入模型的类型,调用不同的 转换函数
     std::unique_ptr<MNN::NetT> netT = std::unique_ptr<MNN::NetT>(new MNN::NetT());
     if (modelPath.model == modelConfig::CAFFE) {
    
         caffe2MNNNet(modelPath.prototxtFile, modelPath.modelFile, modelPath.bizCode, netT);
     } else if (modelPath.model == modelConfig::TENSORFLOW) {
    
         tensorflow2MNNNet(modelPath.modelFile, modelPath.bizCode, netT);
     } else if (modelPath.model == modelConfig::MNN) {
    
         addBizCode(modelPath.modelFile, modelPath.bizCode, netT);
     } else if (modelPath.model == modelConfig::ONNX) {
    
         onnx2MNNNet(modelPath.modelFile, modelPath.bizCode, netT);
     } else if (modelPath.model == modelConfig::TFLITE) {
    
         tflite2MNNNet(modelPath.modelFile, modelPath.bizCode, netT);
     } else {
    
         std::cout << "Not Support Model Type" << std::endl;
     }
     // 转换成功后, 写到新的MNN文件里
     if (modelPath.model != modelConfig::MNN) {
    
         std::cout << "Start to Optimize the MNN Net..." << std::endl;
         std::unique_ptr<MNN::NetT> newNet = optimizeNet(netT, modelPath.forTraining);
         writeFb(newNet, modelPath.MNNModel, modelPath.benchmarkModel, modelPath.saveHalfFloat);
     } else {
    
         writeFb(netT, modelPath.MNNModel, modelPath.benchmarkModel, modelPath.saveHalfFloat);
     }
  }

依次分析这3个过程

1. 参数解析过程

参数解析是常规操作, 没有太多需要关注的. 拿到命令行里传递的原始模型,输出模型名最终生成一个modelConfig结构体

 cxxopts::Options Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv) {
    
     cxxopts::Options options("MNNConvert"); // 借助于一个Options类分析 argc,argv里参数
      modelPath.model = modelPath.MAX_SOURCE;
     // 获得输入模型的格式
     if (result.count("framework")) {
    
         const std::string frameWork = result["framework"].as<std::string>();
         if ("TF" == frameWork) {
    
             modelPath.model = modelConfig::TENSORFLOW;
         } else if ("CAFFE" == frameWork) {
    
             modelPath.model = modelConfig::CAFFE;
         } else if ("ONNX" == frameWork) {
    
             modelPath.model = modelConfig::ONNX;
         } else if ("MNN" == frameWork) {
    
             modelPath.model = modelConfig::MNN;
         } else if ("TFLITE" == frameWork) {
    
             modelPath.model = modelConfig::TFLITE;
         } else {
    
             std::cout << "Framework Input ERROR or Not Support This Model Type Now!" << std::endl;
             std::cout << options.help({
    ""}) << std::endl;
             exit(EXIT_FAILURE);
         }
     } 
      // 获得MNN格式模型的路径
     if (result.count("MNNModel")) {
    
         const std::string MNNModelPath = result["MNNModel"].as<std::string>();
         modelPath.MNNModel             = MNNModelPath;
     } 
 }
 其中
 struct modelConfig {
    
     enum MODEL_SOURCE {
     TENSORFLOW = 0, CAFFE, ONNX, MNN, TFLITE, MAX_SOURCE }; // 可接受的输入模型格式
     // MNN model path
     std::string MNNModel;  // 输出MNN模型的文件路径名
     std::string prototxtFile;
     std::string modelFile; // 输入模型文件
     // bizCode
     std::string bizCode;
     // model source
     MODEL_SOURCE model;
     bool benchmarkModel;
     bool saveHalfFloat;
     bool forTraining = false;
 }
 ​
 // 这部分是 解析命令行, 不是核心就不多分析了
 inline void ParseResult::parse(int &argc, char **&argv) {
    
      while (current != argc) {
    
          // . . .
      }
 }

2. 模型转换过程

MNN支持caffe,tensorflow,onnx,tflite格式模型作为输入. 模型格式不同,解析读取方式也不一样, 这里以ONNX为例进行代码跟踪, 不过主体框架代码是没有区别.

开始分析这个转化过程之前,我们目标是要回答3个问题, 模型输入具体格式是什么样? 输出格式是什么样? 转化过程是怎么样?

模型输入:

以ONNX为例, 跟踪代码之前需要先了解ONNX里计算流图graph的格式,参考链接Open Neural Network Exchange, 必须搞清楚Graph,Node,Input,Output,Tensor在ONNX里的含义,才能往下分析.

这里通俗一点,写下我个人理解

Graph: 计算流图, 完整定义了神经网络算法计算过程,输入和输出. 在数据结构里,图是由点,以及连接这些点的边构成, 而神经网络计算流图是一个有向无环图, 有向才能确定数据走向, 无环才能确保计算结束,不然是死循环了.

Node: 图里的结点, 它包含一个最小运算单元,在神经网络里,叫算子(operator), 以及该算子的输入和输出

Input/Output 在onnx里是 string数组,存的是输入结点和输出接点的名称

Tensor: 神经网络算法里数据块容器, 它可以是向量,矩阵,多维矩阵.

模型输出:

MNN的输出模型是MNN::Net, 通过flatten buffer文件定义的, 因此需要先编译工程生成代码,然后导入代码文件搜索出NetT类

 struct NetT : public flatbuffers::NativeTable {
    
   typedef Net TableType;
   std::string bizCode; // 这个不知道是什么东西
   std::vector<std::unique_ptr<TensorDescribeT>> extraTensorDescribe;
   std::unique_ptr<GpuLibraryT> gpulibrary;
   std::vector<std::unique_ptr<OpT>> oplists; //所有op
   std::vector<std::string> outputName; //输出结点名
   ForwardType preferForwardType;
   NetSource sourceType;
   std::vector<std::string> tensorName; // 所有tensor的名
   int32_t tensorNumber;
   Usage usage;
 };

其中op的定义

 struct OpT : public flatbuffers::NativeTable {
    
   typedef Op TableType;
   std::vector<int32_t> inputIndexes;  // 输入
   OpParameterUnion main; // 跟op相关的数据结构, 不同的op这个结构里内容是不一样的
   std::string name; // op 的name
   std::vector<int32_t> outputIndexes; // 输出
   OpType type;  // op 类型
   MNN_DATA_FORMAT defaultDimentionFormat;  // 该处理tensor的数据排列格式,比如NCHW
   OpT()
       : type(OpType_AbsVal),
         defaultDimentionFormat(MNN_DATA_FORMAT_NHWC) {
    
   }
 };

原始的flatbuffer 文件在目录schema/default/MNN.fbs, 可以参考

转化过程:

通过分析负责转化的onnx2MNNNet函数,跟踪流程,模型转化大致分为几个步骤:

  • 读取原始ONNX模型的pb文件(在分析代码之前,需要先了解下ONNX模型格式构成), 保存到一个OnnxTmpGraph对象
  • 遍历graph的input Node,一一生成MNNOp对象, 放到oplists里(只有非initializers的input结点需要保存,因为带initializers的结点相当于是一个常量,不需要再算).
    出现过的tensor 的名都放到一个叫tensorsName的数组里;
    出现过的tensor,对应生成MNNOp实例, 保存到一个叫mnnNodesMap的 map容器里;
  • 遍历graph的所有结点Node, 一一生成MNNOp对象 , 放到oplists里有两个点要关注,1, 每一个种Op对应有OpConverter对象,生成MNNOp需要调用其run方法,把ONNX算子参数填到MNN算子里. 后面会再分析
    出现过的tensor 的名都放到一个叫tensorsName的数组里
    出现过的tensor,对应生成的MNNOp实例, 保存到一个叫mnnNodesMap的 map容器里;
  • 把所有MNNOp的 input tensor和 output tensor 都找到对应的index, 这里的index是 tensor名在tensorsName数组的偏移. 这样就把MNNOp都串起来构成了图
  • 记录所有的output tensor name
    如下是代码和注释
 // 参数分别是 输入模型文件路径,  输出NetT对象,是MNN里的graph表示
 int onnx2MNNNet(const std::string inputModel, const std::string bizCode, std::unique_ptr<MNN::NetT>& netT) {
    
     onnx::ModelProto onnxModel;
     // 把输入模型读取到 一个 ModelProto对象
     // 这里的ModelProto对象是有 proto文件定义生成的代码, 其内容是按照 ONNX的官方文档定义,原始proto文件在toolsconvertersourceonnx
     // 可以参考ONNX文档 https://github.com/onnx/onnx/blob/master/docs/IR.md
     // 同理,其他类型的输入模型,比如tensorflow 就定义在toolsconvertersourcetensorflow下
     bool success = onnx_read_proto_from_binary(inputModel.c_str(), &onnxModel);
     //...
     //  得到ONNX里的grah图
     const auto& onnxGraph = onnxModel.graph();
     const int nodeCount   = onnxGraph.node_size();
     // 用OnnxTmpGraph对象来存原始的graph, 终于脱离pb了, 后面再分析 这个graph对象
     std::shared_ptr<OnnxTmpGraph> onnxTempGraph(new OnnxTmpGraph(&onnxGraph));
     // 用一个map来存graph里结点Node
     std::map<std::string, MNN::OpT*> mnnNodesMap;
     
     // 保存所有的tensor, 解释下结点和tensor区别, 非const,或者是没有value的结点就是tensor
     //  key 是tensor的name,  value是 int代表 该tensor在图里的index
     std::map<std::string, int> tensorsName;
     //遍历input结点, 找到那些没有initializers的
     const auto& initializers         = onnxTempGraph->mInitializers;
     const auto& inputs               = onnxTempGraph->mInputs;
     const auto& outputs              = onnxTempGraph->mOutputs;
     const auto& constantNodeToDelete = onnxTempGraph->mConstantNodeToDelete;
     for (const auto& iter : inputs) {
    
         bool notHaveInitializer = initializers.find(iter.first) == initializers.end();
         if (notHaveInitializer) {
    
             netT->tensorName.push_back(iter.first);
             tensorsName.insert(std::make_pair(iter.first, tensorsName.size()));
         }
     }
     
      // 把没有initializers的inout  tensor 导入MNN网络
     for (const auto& iter : tensorsName) {
    
         // here tensorsName are true Input node name
         MNN::OpT* MNNOp  = new MNN::OpT;
         MNNOp->name      = iter.first;
         MNNOp->type      = MNN::OpType_Input;
         MNNOp->main.type = MNN::OpParameter_Input;
         auto inputParam  = new MNN::InputT;
         const auto it    = inputs.find(iter.first);
         const auto& tensorInfo = (it->second)->type().tensor_type();
         const int inputDimSize = tensorInfo.shape().dim_size();
         inputParam->dims.resize(inputDimSize);
         for (int i = 0; i < inputDimSize; ++i) {
    
             inputParam->dims[i] = tensorInfo.shape().dim(i).dim_value();
         }
         inputParam->dtype   = onnxOpConverter::convertDataType(tensorInfo.elem_type());
         inputParam->dformat = MNN::MNN_DATA_FORMAT_NCHW;
         MNNOp->outputIndexes.push_back(tensorsName[iter.first]);
         MNNOp->main.value = inputParam;
         mnnNodesMap.insert(std::make_pair(iter.first, MNNOp));
         netT->oplists.emplace_back(MNNOp);
     }
     
     // onnx node 导入到MNN node
     for (int i = 0; i < nodeCount; ++i) {
    
         const auto& onnxNode = onnxGraph.node(i);
         const auto& opType   = onnxNode.op_type();
         // name maybe null, use the first output name as node-name
         const auto& name = onnxNode.output(0);
         // TODO not to use constantNodeToDelete anymore
         if (constantNodeToDelete.find(name) != constantNodeToDelete.end()) {
    
             continue;
         }
         // 找到对应type的op需要的converter
         auto opConverter = onnxOpConverterSuit::get()->search(opType);
         MNN::OpT* MNNOp  = new MNN::OpT;
         MNNOp->name      = name;
         MNNOp->type      = opConverter->opType();
         MNNOp->main.type = opConverter->type();
         mnnNodesMap.insert(std::make_pair(name, MNNOp));
         // convert initializer to be Constant node(op),  常量
         for (int k = 0; k < onnxNode.input_size(); ++k) {
    
    
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值