深度学习计算框架综述(三)模型格式设计

本章
摘要由CSDN通过智能技术生成

 

目录

模型定义分析

Tensorflow

Caffe

TFLite

NCNN

设计自定义模型


Converter的作用是将一种表达形式转换成另一种表达形式,需要遵守的基本原则是:这个转换过程不会丢失内部存储的信息(例如,模型的网络结构以及权重参数)。即两种表达形式应该是等价的。具体到模型转换,我们要求设计的自定义模型格式能够存储其他模型,如tensorflow模型的网络结构与权重参数。

模型定义分析

在设计模型格式之前,我们先看看目前主流的计算框架,Tensorflow、Caffe、TFlite、NCNN是如何定义模型格式的(MACE、MNN等框架的模型定义都值得研究,这里不一一分析了,大家感兴趣可以深入了解一下),推荐一个叫Netron的模型可视化工具,Netron基本支持目前所有主流计算框架的模型。

Tensorflow

对于Tensorflow 而言,和存储网络结构、权重参数信息相关的几个proto文件是graph.proto、node_def.proto、attr_value.proto(源码在这里),下面是graph.proto的内容:

syntax = "proto3";
 
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "tensorflow/core/framework/node_def.proto";
import "tensorflow/core/framework/function.proto";
import "tensorflow/core/framework/versions.proto";
 
 
message GraphDef {
  repeated NodeDef node = 1;
  VersionDef versions = 4;
  int32 version = 3 [deprecated = true];
  FunctionDefLibrary library = 2;
};

我们重点放在 repeated NodeDef node = 1 这一行,可以看到,graph是由多个node组合成的,其他的成员如versions、library,都可以暂时忽略,不影响我们分析,接下来,我们看看node的组成:

syntax = "proto3";
 
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "tensorflow/core/framework/attr_value.proto";
 
message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;
  map<string, AttrValue> attr = 5;
  message ExperimentalDebugInfo {
    repeated string original_node_names = 1;
    repeated string original_func_names = 2;
  };
  ExperimentalDebugInfo experimental_debug_info = 6;
};

我们重点关注name、op、input、attr这4个成员,name表示node的名字,op表示node的类型,如Conv、Relu、Identity,input表示node的输入节点(可能有多个),attr以map<string, AttrValue>的形式保存了node的属性,比如padding、data_format等(如下图),其中AttrValue定义在attr_value.proto中

我们最后再看一下attr_value.proto的内容:

syntax = "proto3";
 
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
 
message AttrValue {
  
  message ListValue {
    repeated bytes s = 2;                        // "list(string)"
    repeated int64 i = 3 [packed = true];        // "list(int)"
    repeated float f = 4 [packed = true];        // "list(float)"
    repeated bool b = 5 [packed = true];         // "list(bool)"
    repeated DataType type = 6 [packed = true];  // "list(type)"
    repeated TensorShapeProto shape = 7;         // "list(shape)"
    repeated TensorProto tensor = 8;             // "list(tensor)"
    repeated NameAttrList func = 9;              // "list(attr)"
  }
 
  oneof value {
    bytes s = 2;                 // "string"
    int64 i = 3;                 // "int"
    float f = 4;                 // "float"
    bool b = 5;                  // "bool"
    DataType type = 6;           // "type"
    TensorShapeProto shape = 7;  // "shape"
    TensorProto tensor = 8;      // "tensor"
    ListValue list = 1;          // any "list(...)"

    NameAttrList func = 10;
    string placeholder = 9;
  }
}
 
message NameAttrList {
 
  string name = 1;
  map<string, AttrValue> attr = 2;
 
}

可以看到,attr_value定义了多种类型的属性值,其中tensor存放了weights、bias这类权重参数。

分析完上面这三个proto文件&

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值