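A converter transforms one representation into another. The basic principle it must follow is that the conversion loses none of the stored information (for example, the model's network structure and weight parameters); in other words, the two representations must be equivalent. For model conversion specifically, this means the custom model format we design must be able to store models from other frameworks, such as the network structure and weight parameters of a TensorFlow model.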
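The lossless requirement can be phrased as a round-trip test: convert to the custom format and back, then compare the result with the original. The sketch below assumes a toy graph held in hypothetical Python classes (SrcNode for the source model, MyLayer for our own format, and the helpers to_custom, from_custom and to_custom_lossy); none of these names come from TensorFlow or any real converter.

```python
from dataclasses import dataclass, field
from typing import Dict, List

# Hypothetical source-side node, loosely modeled on TensorFlow's NodeDef.
@dataclass
class SrcNode:
    name: str
    op: str
    inputs: List[str]
    attrs: Dict[str, object] = field(default_factory=dict)

# Hypothetical layer in our own custom format; the field names are our choice.
@dataclass
class MyLayer:
    layer_name: str
    layer_type: str
    bottoms: List[str]
    params: Dict[str, object] = field(default_factory=dict)

def to_custom(nodes: List[SrcNode]) -> List[MyLayer]:
    # Copies every field: structure (name, op, inputs) and parameters (attrs).
    return [MyLayer(n.name, n.op, list(n.inputs), dict(n.attrs)) for n in nodes]

def to_custom_lossy(nodes: List[SrcNode]) -> List[MyLayer]:
    # Drops attrs: the structure survives, but padding and weights are lost.
    return [MyLayer(n.name, n.op, list(n.inputs)) for n in nodes]

def from_custom(layers: List[MyLayer]) -> List[SrcNode]:
    return [SrcNode(l.layer_name, l.layer_type, list(l.bottoms), dict(l.params)) for l in layers]

def is_lossless(nodes: List[SrcNode]) -> bool:
    """Converting there and back must reproduce the input exactly."""
    return from_custom(to_custom(nodes)) == nodes

graph = [
    SrcNode("input", "Placeholder", []),
    SrcNode("conv1/weights", "Const", [], {"value": [0.1, -0.2, 0.3]}),
    SrcNode("conv1", "Conv2D", ["input", "conv1/weights"], {"padding": "SAME"}),
    SrcNode("relu1", "Relu", ["conv1"]),
]
print(is_lossless(graph))                              # True
print(from_custom(to_custom_lossy(graph)) == graph)    # False
```

Passing the round trip is a practical sanity check rather than a proof of equivalence: it only shows that this particular mapping keeps every field we chose to model.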
Model Definition Analysis
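Before designing our own model format, let's look at how today's mainstream frameworks (TensorFlow, Caffe, TFLite and NCNN) define theirs. The model definitions of other frameworks such as MACE and MNN are also worth studying; we won't analyze them one by one here, but interested readers can dig into them. I also recommend a model visualization tool called Netron, which supports models from virtually every mainstream framework.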
TensorFlow
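For TensorFlow, the proto files involved in storing the network structure and weight parameters are graph.proto, node_def.proto and attr_value.proto (all under tensorflow/core/framework in the TensorFlow source tree). Here is graph.proto: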
```protobuf
syntax = "proto3";

package tensorflow;

option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";

import "tensorflow/core/framework/node_def.proto";
import "tensorflow/core/framework/function.proto";
import "tensorflow/core/framework/versions.proto";

message GraphDef {
  repeated NodeDef node = 1;              // every node (op) in the graph
  VersionDef versions = 4;
  int32 version = 3 [deprecated = true];
  FunctionDefLibrary library = 2;         // user-defined functions
}
```
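We focus on the line repeated NodeDef node = 1, which shows that a graph is composed of multiple nodes. The other members, such as versions and library, can be ignored for now without affecting our analysis. Next, let's look at what a node is made of (node_def.proto):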
```protobuf
syntax = "proto3";

package tensorflow;

option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";

import "tensorflow/core/framework/attr_value.proto";

message NodeDef {
  string name = 1;                  // node name, unique within the graph
  string op = 2;                    // op type, e.g. "Conv2D", "Relu"
  repeated string input = 3;        // "node", "node:1" (output index) or "^node" (control input)
  string device = 4;
  map<string, AttrValue> attr = 5;  // attributes, e.g. "padding", "data_format"

  message ExperimentalDebugInfo {
    repeated string original_node_names = 1;
    repeated string original_func_names = 2;
  }
  ExperimentalDebugInfo experimental_debug_info = 6;
}
```
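Putting graph.proto and node_def.proto together, a GraphDef is essentially a flat list of NodeDef messages, so a converter's first pass over a model is usually a plain loop over that list. The sketch below uses hypothetical plain-Python dataclasses that mirror only the fields discussed here; index_by_name and op_histogram are illustrative helpers, not TensorFlow APIs.

```python
from dataclasses import dataclass, field
from typing import Dict, List

# Hypothetical stand-ins for the protobuf messages; only the fields discussed
# here are mirrored (versions, library, device, ... are omitted).
@dataclass
class NodeDef:
    name: str
    op: str
    input: List[str] = field(default_factory=list)
    attr: Dict[str, object] = field(default_factory=dict)

@dataclass
class GraphDef:
    node: List[NodeDef] = field(default_factory=list)  # repeated NodeDef node = 1

def index_by_name(graph: GraphDef) -> Dict[str, NodeDef]:
    """Build a name -> node lookup table, rejecting duplicate names."""
    table: Dict[str, NodeDef] = {}
    for n in graph.node:
        if n.name in table:
            raise ValueError(f"duplicate node name: {n.name}")
        table[n.name] = n
    return table

def op_histogram(graph: GraphDef) -> Dict[str, int]:
    """Count nodes per op type, i.e. which ops the target format must support."""
    counts: Dict[str, int] = {}
    for n in graph.node:
        counts[n.op] = counts.get(n.op, 0) + 1
    return counts

g = GraphDef(node=[
    NodeDef("input", "Placeholder"),
    NodeDef("w", "Const"),
    NodeDef("conv", "Conv2D", ["input", "w"]),
    NodeDef("relu", "Relu", ["conv"]),
])
print(op_histogram(g))                 # {'Placeholder': 1, 'Const': 1, 'Conv2D': 1, 'Relu': 1}
print(index_by_name(g)["conv"].input)  # ['input', 'w']
```

With TensorFlow installed, the real message class is tf.compat.v1.GraphDef; a serialized .pb file can be loaded with graph_def.ParseFromString(data), and the same loops then work over graph_def.node.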
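We mainly care about four members: name, op, input and attr. name is the node's name; op is the node's type, such as Conv2D, Relu or Identity; input lists the node's inputs (there may be several); and attr stores the node's attributes, such as padding and data_format, as a map<string, AttrValue> (see the figure below). AttrValue is defined in attr_value.proto.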
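Note that input does not point at nodes directly: each entry is a string of the form "node" (output 0 of that node), "node:k" (output k) or "^node" (a control dependency that carries no data). A converter therefore parses these strings to recover the graph edges, and usually emits layers in topological order, since a GraphDef is not required to list producers before their consumers. The sketch below is a minimal, hypothetical version of both steps (parse_input, topo_order) using Kahn's algorithm; it assumes an acyclic graph (tf.while_loop control flow introduces cycles) in which every referenced node exists.

```python
from collections import deque
from typing import Dict, List, Tuple

def parse_input(s: str) -> Tuple[str, int, bool]:
    """Split a NodeDef input string into (node_name, output_index, is_control).

    "conv1"   -> ("conv1", 0, False)   output 0 is implied
    "split:1" -> ("split", 1, False)   second output of node "split"
    "^init"   -> ("init", -1, True)    control dependency, no data
    """
    if s.startswith("^"):
        return s[1:], -1, True
    name, sep, idx = s.rpartition(":")
    if sep and idx.isdigit():
        return name, int(idx), False
    return s, 0, False

def topo_order(inputs: Dict[str, List[str]]) -> List[str]:
    """Kahn's algorithm over the edges implied by each node's input list."""
    indegree = {name: 0 for name in inputs}
    consumers: Dict[str, List[str]] = {name: [] for name in inputs}
    for name, ins in inputs.items():
        for s in ins:
            src, _, _ = parse_input(s)
            indegree[name] += 1
            consumers[src].append(name)
    ready = deque(n for n, d in indegree.items() if d == 0)
    order: List[str] = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for c in consumers[n]:
            indegree[c] -= 1
            if indegree[c] == 0:
                ready.append(c)
    if len(order) != len(inputs):
        raise ValueError("graph contains a cycle")
    return order

# node name -> its input strings (a hypothetical toy graph)
graph = {
    "input": [],
    "split": ["input"],
    "add": ["split:0", "split:1"],
    "init": [],
    "relu": ["add", "^init"],
}
print(parse_input("split:1"))  # ('split', 1, False)
print(topo_order(graph))       # ['input', 'init', 'split', 'add', 'relu']
```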
Finally, let's look at attr_value.proto:
```protobuf
syntax = "proto3";

package tensorflow;

option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";

import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";

message AttrValue {
  message ListValue {
    repeated bytes s = 2;                         // "list(string)"
    repeated int64 i = 3 [packed = true];         // "list(int)"
    repeated float f = 4 [packed = true];         // "list(float)"
    repeated bool b = 5 [packed = true];          // "list(bool)"
    repeated DataType type = 6 [packed = true];   // "list(type)"
    repeated TensorShapeProto shape = 7;          // "list(shape)"
    repeated TensorProto tensor = 8;              // "list(tensor)"
    repeated NameAttrList func = 9;               // "list(attr)"
  }

  oneof value {
    bytes s = 2;                  // "string"
    int64 i = 3;                  // "int"
    float f = 4;                  // "float"
    bool b = 5;                   // "bool"
    DataType type = 6;            // "type"
    TensorShapeProto shape = 7;   // "shape"
    TensorProto tensor = 8;       // "tensor"
    ListValue list = 1;           // any "list(...)"
    NameAttrList func = 10;
    string placeholder = 9;
  }
}

message NameAttrList {
  string name = 1;
  map<string, AttrValue> attr = 2;
}
```
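As you can see, attr_value.proto defines many kinds of attribute values. Among them, the tensor field (a TensorProto) is what holds weight parameters such as weights and biases; in a frozen graph these typically sit in the value attribute of Const nodes.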
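Because value is a oneof, at most one of its fields is set on any given AttrValue, so a converter first checks which one is present before reading it. The sketch below mirrors this with hypothetical dataclasses (a trimmed AttrValue and a float32-only TensorProto) and decodes weights stored either as raw tensor_content bytes or as a float_val list. It assumes little-endian float32 data and that a short float_val list is padded by repeating its last value (our reading of how TensorFlow expands such constants). With the real protobuf classes, WhichOneof("value") reports the set field, and tf.make_ndarray performs the decoding for every dtype.

```python
import struct
from dataclasses import dataclass, field
from math import prod
from typing import List, Optional

# Hypothetical mirror of the TensorProto fields used here (float32 only).
@dataclass
class TensorProto:
    shape: List[int]
    tensor_content: bytes = b""  # raw, densely packed element bytes
    float_val: List[float] = field(default_factory=list)

# Hypothetical, trimmed mirror of AttrValue: at most one field should be set.
@dataclass
class AttrValue:
    s: Optional[bytes] = None
    i: Optional[int] = None
    f: Optional[float] = None
    b: Optional[bool] = None
    tensor: Optional[TensorProto] = None

    def which_oneof(self) -> Optional[str]:
        """Name of the field that is set, similar to protobuf's WhichOneof("value")."""
        for name in ("s", "i", "f", "b", "tensor"):
            if getattr(self, name) is not None:
                return name
        return None

def decode_float_tensor(t: TensorProto) -> List[float]:
    """Flatten a float32 TensorProto into a list of Python floats."""
    n = prod(t.shape)  # prod([]) == 1, i.e. a scalar
    if t.tensor_content:
        # Assumption: little-endian float32, 4 bytes per element.
        return list(struct.unpack(f"<{n}f", t.tensor_content))
    # Assumption: a short float_val list is padded with its last value.
    vals = list(t.float_val)
    if vals and len(vals) < n:
        vals += [vals[-1]] * (n - len(vals))
    return vals

# A 2x2 weight matrix stored compactly as raw bytes.
weights = AttrValue(tensor=TensorProto(
    shape=[2, 2], tensor_content=struct.pack("<4f", 0.5, -1.0, 2.0, 0.25)))
# A bias of shape [3] given by a single float_val entry.
bias = AttrValue(tensor=TensorProto(shape=[3], float_val=[0.0]))
padding = AttrValue(s=b"SAME")

for attr in (weights, bias, padding):
    kind = attr.which_oneof()
    if kind == "tensor":
        print(decode_float_tensor(attr.tensor))
    else:
        print(kind, getattr(attr, kind))
```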
After analyzing the three proto files above