python模型持久化_【腾讯机智】Tensorflow 模型持久化与静态分析

在机器学习工程中,模型持久化是一个不可或缺的操作,无论是训练工程中用于容灾的断点续训 (re-training)、模型调优中的 fine-tuning,还是模型训练完成后部署到线上推理环境,都需要对模型进行序列化和反序列化。

本文聚焦于 Tensorflow 中的模型持久化,从表示图的基本数据结构出发,尝试梳理了几种常见的模型持久化方式、内部实现及其异同,最后对几种常用的模型静态分析方式进行了介绍。文章篇幅较长,但争取不做文档的搬运工。

1. 基本数据结构

我们知道,计算图 (graph) 是 Tensorflow 用于表达计算任务的一个核心概念。开发者在前端 (Python, Java, etc.) 使用 TF API 构建神经网络的结构,而 Tensorflow 会将前端代码所描绘的图转换(即“序列化”)成 Protocol Buffer 对象,以 session 作为桥梁传到后端后,最后再执行真正的 Protocol Buffer 对象中所定义的计算过程。也就是说,计算图 (graph) 的基本数据结构都是由 Protocol Buffer 来定义的,不同的模型持久化方式本质上都只是针对这些 pb 结构的序列化。因此,我们首先来看看一个计算图 (graph) 是由哪些 pb 结构所定义的。

1.1. MetaGraphDef

MetaGraphDef 是用于表示一个计算图的高级数据结构,所谓“高级”,是因为其包含了计算图中几乎所有的静态结构信息和元数据信息,借助基于 MetaGraphDef 导出的模型(以及额外的数据文件),我们可以在已有模型的基础上实现断点续训和推理计算等操作。

//tensorflow/core/protobuf/meta_graph.proto

message MetaGraphDef {

MeatInfoDef meta_info_def = 1;

GraphDef graph_def = 2;

SaverDef saver_def = 3;

map collection_def = 4;

map signature_def = 5;

repeated AssetFileDef asset_file_def = 6;

}

从以上 MetaGraphDef 的定义中,我们可以看到 MetaGraphDef 包含了 6 种类型的字段,但由于篇幅限制,这里只关注与模型持久化过程中最密切相关的 pb 结构定义,完整的定义可以查看在注释中标明的 proto 源文件。

1.2. MetaInfoDefMeta information regarding the graph to be exported. To be used by users of this protocol buffer to encode information regarding their meta graph.

MetaInfoDef 类型的字段包含了计算图中的所有元信息。e.g. 图的版本信息、模型已迭代步数、计算图所用到的 Op 信息等:

//tensorflow/core/protobuf/meta_graph.proto

message MetaInfoDef {

// 计算图的版本号

string meta_graph_version = 1;

// 该 list 记录了计算图中使用到的所有 Op 的信息(不包括没有使用到的 Op);

// 该函数只记录 Op 信息,不记录 Op 执行的次数

OpList stripped_op_list = 2;

google.protobuf.Any any_info = 3;

// 用户指定的标签

repeated string tags = 4;

string tensorflow_version = 5;

string tensorflow_git_version = 6;

bool stripped_default_attrs = 7;

}

其中,OpList 类型的 stripped_op_list 字段实际上是一个 OpDef 对象的 list,我们来看看 OpDef 的定义描述:A NodeDef in a GraphDef specifies an Op by using the "op" field which should match the name of a OpDef.

可以理解为,OpDef 相当于 class,而我们下文将会讲到的 NodeDef 就相当于对应的 instance。

//tensorflow/core/framework/op_def.proto

message OpDef {

string name = 1; // 定义了运算的名称(唯一标识符)

repeated ArgDef input_arg = 2; // 定义了输入列表

repeated ArgDef output_arg = 3; // 定义了输出列表

repeated AttrDef attr = 4; // 定义了 Op 的其它参数信息

string summary = 5;

string description = 6;

OpDeprecation deprecation = 8;

bool is_commutative = 18;

bool is_aggregate = 16

bool is_stateful = 17;

bool allows_uninitialized_input = 19;

};

e.g. Add Op 的 OpDef 定义:

op {

name: "Add"

input_arg{

name: "x"

type_attr:"T"

}

input_arg{

name: "y"

type_attr:"T"

}

output_arg{

name: "z"

type_attr:"T"

}

attr{

name:"T"

type:"type"

allow_values{

list{

type:DT_HALF

type:DT_FLOAT

...

}

}

}

}

上面给出的例子是名称为 Add 的 OpDef。这个 Op 的输入有两个,输出有一个,输入输出属性均指定了类型属性 typr_attr,并且这个属性的值为 T。

1.3. GraphDef

GraphDef 是 MetaGraphDef 的核心内容,其主要包含了一个 NodeDef 类型的列表。由于 MetaInfoDef 中已经包含所有运算的具体信息,所以 GraphDef 主要关注的是 Node 的连接结构:

//tensorflow/core/framework/graph.proto

message GraphDef {

// GraphDef 的主要信息存储在 node 属性中,它记录了计算图上所有的节点信息

repeated NodeDef node = 1;

// Tensorflow 版本号

VersionDef versions = 4;

}

其中,NodeDef 类型的字段对应了 python graph 中的每个 operation:

//tensorflow/core/framework/node_def.proto

message NodeDef {

// name 是一个 node 的唯一标识符,可以通过 node 的名称来获得相应的 node

string name = 1;

// op 属性给出了该 node 使用的 op 的名称

// 通过这个名称可以在 meta_info_def 中找到该 op 的具体信息

string op = 2;

// input 属性是一个字符串列表,定义了 node 的输入

repeated string input = 3;

// 执行这个 op 的设备,为空时表示自动选择

string device = 4;

// 和当前 op 有关的配置信息

map attr = 5;

}

NOTE: 在 NodeDef 中没有关于输出 (output) 的定义,实际上,Tensorflow 中都是通过 Input 来定义 Node 之间的连接信息的: - 通过 input 定义输入的节点名称。 - 通过 attr (attribute) 来定义输出 tensor 的数据类型和 shape。

1.4. CollectionDefCollectionDef map that further describes additional components of the model, such as Variables, QueueRunners, etc.

CollectionDef 类型的字段保存了计算图中需要特殊的标注以方便 import_meta_graph 后取回的节点集合(如:"train_op"、"prediction" 等):

//tensorflow/core/protobuf/meta_graph.proto

message CollectionDef {

// 节点集合

message Nodelist {

repeated string value = 1;

}

// 字符串或者序列化之后的 Procotol Buffer 集合

message BytesList {

repeated bytes value = 1 ;

}

// 整数集合

message Int64List {

repeated int64 value = 1[packed = true];

}

// 实数集合

message FloatList {

repeated float value = 1[packed = true] ;

}

message AnyList {

repeated google.protobuf.Any value= 1;

}

// 可以维护 4 类不同的集合

oneof kind {

NodeList node_list = 1;

BytesList bytes_lista = 2;

Int64List int64_list = 3;

Floatlist float_list = 4;

AnyList any_list = 5;

}

}

2. 持久化方式

了解了 Tensorflow 中图的数据结构之后,我们再来看看 Tensorflow 针对以上这些 pb 结构都提供了哪些持久化方式以及相应的 API 呢?

2.1. GraphDef

2.1.1. 导出:tf.io.write_graph

tf.io.write_graph 方法提供了导出 GraphDef 的功能接口:

# tensorflow/python/framework/graph_io.py

@tf_export('io.write_graph', v1=['io.write_graph', 'train.write_graph'])

def write_graph(graph_or_graph_def, logdir, name, as_text=True):

通过查看该方法的实现,我们可以发现核心逻辑其实是调用了 pb 对象的 SerializeToString 方法进行序列化,再写入到文件中:

if as_text:

file_io.atomic_write_string_to_file(path,

text_format.MessageToString(graph_def))

else:

file_io.atomic_write_string_to_file(path, graph_def.SerializeToString())

e.g.

import tensorflow as tf

v = tf.Variable(0, name='my_variable')

sess = tf.Session()

tf.io.write_graph(sess.graph_def, '/data/my-model', 'train.json')

NOTE: tf.io.write_graph 方法提供了一个 as_text 参数,通过该参数可以设置执行 proto 的序列化时是否按照 ASCII 编码进行。由于 as_text 默认为 true,因此可直接输出为 pbtxt/json(可读文本)格式;若将 as_text 设置为 false,则输出为 pb(二进制)格式。

2.1.2. 导入:tf.graph_util.import_graph_def

tf.graph_util.import_graph_def() 方法能将 GraphDef 对象导入到当前 session 的 default graph 中,但使用该方法需要手动完成反序列化的过程:

# tensorflow/python/framework/importer.py

@tf_export('graph_util.import_graph_def', 'import_graph_def')

def import_graph_def(graph_def,

input_map=None,

return_elements=None,

name=None,

op_dict=None,

producer_op_list=None)

e.g. 从 GraphDef 的序列化 .pb 文件中构建计算图:

model_path = 'graph_def.pb'

with tf.gfile.GFile(model_path, "rb") as f:

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

tf.import_graph_def(graph_def, name='XXX')

graph = tf.get_default_graph()

print(g.get_tensor_by_name('XXX/image_tensor:0'))

NOTE: 从 GraphDef 中恢复构建的图可以被训练吗?

不可以。在 GraphDef 中只有网络节点的连接信息(静态结构),而没有 MetaInfoDef 中的 Op 信息和 CollectionDef 中的 train_op 集合。因此,我们无法直接从 GraphDef 来构建图并恢复训练。

e.g.

with tf.Graph().as_default() as graph:

tf.import_graph_def("graph_def_path")

with tf.Session() as sess:

# 由于只有 graph_def,而没有 collection_def,所以这里只会返回一个空的 list

tf.trainable_variables()

2.2. FrozenGraphDef

在实际的线上 inference 中,我们通常不会直接使用结构与权值分离的持久化模型文件,而是使用 Tensorflow 提供的一种模型固化方案:GraphDef 虽然不能保存 variable 的权值,但可以保存 Constant。通过将 variables 直接转为 tf.constant 存储在 GraphDef 中的 NodeDef 里,就可以将整个 Tensorflow 计算图存放在一个文件中,固化生成的 GraphDef 称为 FrozenGraphDef。

2.2.1. API: graph_util.convert_variables_to_constants

Tensorflow 为模型固化提供了 graph_util.convert_variables_to_constants 方法,该方法将指定 session 中的 variables 权值转化为 const ops 并替换 GraphDef 中的相应 variables 节点:

# tensorflow/python/framework/graph_util_impl.py

@tf_export(v1=["graph_util.convert_variables_to_constants"])

def convert_variables_to_constants(sess,

input_graph_def,

output_node_names,

variable_names_whitelist=None,

variable_names_blacklist=None):

2.2.2. Tool: tensorflow/python/tools/freeze_graph.py

从 1.3.0 版本开始,Tensorflow 在 tensorflow/python/tools 目录下提供了转换脚本 freeze_graph.py 来将图中的 variable 固化 (frozen) 成 constant 存储在 GraphDef 里。

e.g.

$ python $TENSORFLOW_DIR/python/tools/freeze_graph \

--input_graph=some_graph_def.pb \

--input_checkpoint=model.ckpt-8361242 \

--output_graph=/tmp/frozen_graph.pb \

--output_node_names=softmax

从源码中可以看到,这个脚本实际上也是调用了 graph_util.convert_variables_to_constants 进行转换:

# tensorflow/python/tools/freeze_graph.py

if input_meta_graph_def:

output_graph_def = graph_util.convert_variables_to_constants(

sess,

input_meta_graph_def.graph_def,

output_node_names.replace(" ", "").split(","),

variable_names_whitelist=variable_names_whitelist,

variable_names_blacklist=variable_names_blacklist)

else:

output_graph_def = graph_util.convert_variables_to_constants(

sess,

input_graph_def,

output_node_names.replace(" ", "").split(","),

variable_names_whitelist=variable_names_whitelist,

variable_names_blacklist=variable_names_blacklist)

# Write GraphDef to file if output path has been given.

if output_graph:

with gfile.GFile(output_graph, "wb") as f:

f.write(output_graph_def.SerializeToString())

2.3. MetaGraphDef

2.3.1. 导出:tf.train.export_meta_graph

# tensorflow/python/training/saver.py

@tf_export(v1=["train.export_meta_graph"])

def export_meta_graph(filename=None,

meta_info_def=None,

graph_def=None,

saver_def=None,

collection_list=None,

as_text=False,

graph=None,

export_scope=None,

clear_devices=False,

clear_extraneous_savers=False,

strip_default_attrs=False,

**kwargs):

NOTE: 当 GraphDef 大小超过 2GB 时,该方法会抛出 ValueError。

我们来看看这个方法内部是如何导出 meta_graph 的呢?

# tensorflow/python/training/saver.py

meta_graph_def, _ = meta_graph.export_scoped_meta_graph(

filename=filename,

meta_info_def=meta_info_def,

graph_def=graph_def,

saver_def=saver_def,

collection_list=collection_list,

as_text=as_text,

graph=graph,

export_scope=export_scope,

clear_devices=clear_devices,

clear_extraneous_savers=clear_extraneous_savers,

strip_default_attrs=strip_default_attrs,

**kwargs)

# tensorflow/python/framework/meta_graph.py

scoped_meta_graph_def = create_meta_graph_def(

graph_def=graph_def,

graph=graph,

export_scope=export_scope,

exclude_nodes=exclude_nodes,

clear_extraneous_savers=clear_extraneous_savers,

saver_def=saver_def,

strip_default_attrs=strip_default_attrs,

**kwargs)

if filename:

graph_io.write_graph(

scoped_meta_graph_def,

os.path.dirname(filename),

os.path.basename(filename),

as_text=as_text)

# tensorflow/python/framework/graph_io.py

if as_text:

file_io.atomic_write_string_to_file(path,

text_format.MessageToString(graph_def))

else:

file_io.atomic_write_string_to_file(path, graph_def.SerializeToString())

从源码中可以一马平川的看到,该方法根据 MetaGraphDef 的定义将当前的 graph 的各种信息一一转换为 pb 对象并填充到了 MetaGraphDef 对象中,最后调用 SerializeToString 序列化后写入到文件中。

2.3.2. 导入:tf.train.import_meta_graphThe function then adds all the nodes from the graph_def field to the current graph, recreates all the collections, and returns a saver constructed from the saver_def field.

使用 tf.train.import_meta_graph 方法,我们可以从一个 meta_graph 对象中加载所有的 nodes 定义、collections 定义:

# tensorflow/python/training/saver.py

@tf_export(v1=["train.import_meta_graph"])

def import_meta_graph(meta_graph_or_file, clear_devices=False,

import_scope=None, **kwargs):

e.g. 从 MetaGraphDef 中恢复图结构,并从 checkpoint 中恢复训练好的网络权值:

with tf.Session() as sess:

# 1. 从 `.meta` 文件导入原始网络结构图到当前 session 中:

new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')

# 2. 加载了网络结构图之后还需要加载变量数据,使用 `restore()` 方法恢复模型的变量参数

new_saver.restore(sess, tf.train.latest_checkpoint('./'))

# 3. 在此之后,模型中变量 w1 和 w2 的取值已经恢复

print(sess.run('w1:0'))

NOTE: 只从 MetaGraphDef 中恢复图结构而不加载 variables 权值可以 re-training 吗?

是可以的。Meta Graph 中包含了 variable 的节点信息,但没有 Variable 的实际取值,所以从 Meta Graph 中恢复的图,其训练实际上是从随机初始化的值开始的。训练中 Variable 的实际值都保存在 checkpoint 中,如果要从之前训练的状态继续恢复训练,就要从 checkpoint 中进行恢复。

2.4. Checkpoint

2.4.1. Checkpoint 格式定义

Checkpoint 格式的持久化模型全面保存了训练中某时间截面的模型信息(包括参数权值、超参数、梯度等)。

虽然生成 checkpoint 时只会指定一个文件路径,但实际上将会生成结构与权重分离的 4 个文件。e.g.

---checkpoint

---model.ckpt-240000.data-00000-of-00001

---model.ckpt-240000.index

---model.ckpt-240000.metacheckpoint: 该文件以 CheckpointState (python/training/checkpoint_state.proto) 格式保存了所有 checkpoint 文件的列表:

message CheckpointState { // 保存了最新模型文件的文件名 string model_checkpoint_path = 1; // 列出当前还未被删除的所有模型文件的文件名 repeated string all_model_checkpoint_paths = 2; double last_preserved_timestamp = 4;}model.ckpt-xxx.meta: 该文件保存了完整的静态图结构,实际上就是 MetaGraphDef 的二进制序列化文件(可通过 tf.MetaGraphDef.ParseFromString 转换为文本序列化格式后直接 print)。

model.ckpt-xxx.data-xxx-of-xxx: 该文件为 SSTable (Sorted String Table, a file of key/value string pairs, sorted by keys) 格式存储的数据文件,保存了所有变量的最新取值,即各个网络节点的权值。

model.ckpt-xxx.index: 该文件是参数索引文件,为数据文件提供索引,保存着参数的基本信息,但不保存参数的值。

P.S.在 Tensorflow 0.11 版本之前,variables 信息存储在一个 .ckpt 文件中;从 Tensorflow 0.11 版本开始,variables 信息通过 .index 和 .data 2 个文件来存储。

相比下文介绍的 Tensorflow Serving 标准格式 SavedModel,Checkpoint 通常更适用于 re-training。

2.4.2. tf.train.Saver()

Tensorflow 提供了一个较为高级的 API 来保存和还原一个 checkpoint 格式的模型:

# tensorflow/python/training/saver.py

@tf_export(v1=["train.Saver"])

class Saver(object):

关于 API 的细节以及具体如何使用,网上都可以找到非常详细的介绍,这里就不再赘述了,我们主要关注一下模型导出和导入的核心 API 内部的执行流程。

2.4.2.1. meta_graph && variables 导出:save

tf.train.Saver() 的 save 方法提供了以 checkpoint 格式来保存当前 session 中的默认图的 variables 和 meta_graph 的能力:

def save(self,

sess,

save_path,

global_step=None,

latest_filename=None,

meta_graph_suffix="meta",

write_meta_graph=True,

write_state=True,

strip_default_attrs=False):

其中我们可以关注一下 write_meta_graph 参数:在训练的时候,假设每 1000 次就保存一次模型,但是这些保存的文件中变化的仅仅是模型的 variables,而模型结构没有变化,也就没必要重复保存 .meta 文件,所以这种情况下我们就可以设置让网络结构不重复保存。

从 save 方法 (tensorflow/python/training/saver.py) 的实现中可以发现,该方法的执行主要包括了 3 个过程:运行在构造方法时添加的 save op,以 SSTable 格式保存参数 var_list 中指定的 variables 到 .index && .data 文件中:

model_checkpoint_path = sess.run(

self.saver_def.save_tensor_name,

{self.saver_def.filename_tensor_name: checkpoint_file})

这里执行的 `self.saver_def.save_tensor_name` 从哪来呢?我们回到类构造方法中看一下:

# 在 self._builder._build_internal 方法中,

# 通过 validate_and_slice_inputs 方法根据传入的 var_list,找到指定的变量

# (names_to_saveables 就是 self._var_list)

saveables = saveable_object_util.validate_and_slice_inputs(

names_to_saveables)

# 在 self._builder._build_internal 方法中,获取 variables 后,

# 调用 _AddSaveOps 向 graph 中添加 Save Ops

if build_save:

save_tensor = self._AddSaveOps(filename_tensor, saveables)

# 继续跟进 self._AddSaveOps 后会发现,

# 被添加的 Save Ops 就是 io_ops.save_v2

if self._write_version == saver_pb2.SaverDef.V1:

return io_ops._save(

filename=filename_tensor,

tensor_names=tensor_names,

tensors=tensors,

tensor_slices=tensor_slices)

elif self._write_version == saver_pb2.SaverDef.V2:

return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,

tensors)

最后回到 `self.saver_def.save_tensor_name`,实际上就是在构造方法中添加的 `io_ops.save_v2` op:

# python/ops/io_ops.py

return gen_io_ops.save(filename, tensor_names, tensors, name=name)

// core/kernels/save_op.cc

void Compute(OpKernelContext* context) override {

SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder, false);

}

// core/kernels/save_restore_tensor.cc

writer.Add(name, shape, slice, input.flat().data());

2. 根据 CheckpointState pb 的定义,保存 / 更新 checkpoint 文件:

# 若 write_state 参数被设置为 False,则不保存 checkpoint 文件

if write_state:

self._RecordLastCheckpoint(model_checkpoint_path)

checkpoint_management.update_checkpoint_state_internal(

save_dir=save_path_parent,

model_checkpoint_path=model_checkpoint_path,

all_model_checkpoint_paths=self.last_checkpoints,

latest_filename=latest_filename,

save_relative_paths=self._save_relative_paths)

self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)

3. 调用 export_meta_graph 方法,保存 meta_graph 到 .meta 文件中:

# 若 write_meta_graph 参数被设置为 False,则不保存 meta_graph 文件

if write_meta_graph:

meta_graph_filename = checkpoint_management.meta_graph_filename(

checkpoint_file, meta_graph_suffix=meta_graph_suffix)

if not context.executing_eagerly():

with sess.graph.as_default():

self.export_meta_graph(

meta_graph_filename, strip_default_attrs=strip_default_attrs)

继续跟进我们就会看到一个熟悉的身影,是的,`Saver` 类中的 `export_meta_graph` 方法实际上直接调用了 `tf.train.export_meta_graph` 来实现 meta_graph 的导出逻辑。

2.4.2.2. variables 恢复:restore

tf.train.Saver() 的 restore 方法提供了将 checkpoint 中的 variables 权值恢复到当前 session 的默认图中的能力:

def restore(self, sess, save_path):

同样我们来看一下 restore 方法的主要执行流程:

# python/training/saver.py

sess.run(self.saver_def.restore_op_name,

{self.saver_def.filename_tensor_name: save_path})

注意到 self.saver_def.restore_op_name,这同样是在构造方法中添加的 restore op:

saveables = saveable_object_util.validate_and_slice_inputs(

names_to_saveables)

if build_restore:

restore_op = self._AddRestoreOps(filename_tensor, saveables,

restore_sequentially, reshape)

其中,saveable_object_util.validate_and_slice_inputs 方法会根据在构造方法中传入的 var_list 参数,在 graph 中找到指定 name 的 variable 节点,并从 .index 和 .data 文件中读取变量权值:

// core/kernels/restore_op.ccvoid Compute(OpKernelContext* context) override {

RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,

preferred_shard_, false, 0);

}

// core/kernels/save_restore_tensor.ccvoid RestoreTensor(OpKernelContext* context,

checkpoint::TensorSliceReader::OpenTableFunction open_func,

int preferred_shard, bool restore_slice, int restore_index)

NOTE:对于将要进行 restore 的变量不需要再进行初始化,因为 restore 操作本身就相当于一个变量初始化的操作。

save() 方法既能在 checkpoint 中保存图的静态结构 MetaGraphDef,也能保存网络中 Variables 的当前权值;而 restore() 方法只能恢复 checkpoint 中保存的网络中 variables 的权值,而不能恢复图的静态结构。如果要从 MetaGraphDef 恢复图结构,需要另外使用 tf.train.import_meta_graph()。这是其实为了方便用户,因为大多数时候我们不需要从 MetaGraph 恢复图,而是在 python 中构建模型结构,再从 checkpoint 中恢复对应的 variables 权值。

如果图结构不同,可以直接 restore 吗?

不可以。保存模型的变量依赖于计算图上的节点,但执行 restore 时我们仅加载了模型的变量,并没有加载模型的计算图,所以如果我们想要正确的加载模型,就需要先定义一个相同的计算图结构。

2.5. ExportModel (Deprecated)

最早 TensorFlow Serving 推荐使用通过 Exportor 接口导出的模型,但现在这个接口即将被废弃,官方推荐使用最新的 SavedModel 格式

2.6. SavedModel

SavedModel (tensorflow/python/saved_model/) 是当前 Tensorflow Serving 的标准格式,可恢复且对外界不透明,它允许高层系统和工具来生成、消费、和转换 Tensorflow 模型。TensorFlow 和 Keras 模型都推荐使用这种模型格式。

SavedModel 使用 tf.saved_model 接口进行导出,是 GraphDef 和 CheckPoint 的结合体,包含了模型 Graph 结构和 variables 权值,从 SavedModel 中可以提取 GraphDef 和 Checkpoint 对象。

2.6.1. 格式特性tag

一个 SavedModel 可以包含多个不同的 MetaGraphDef,这个特性允许我们为不同的任务订制不同的计算图(不同的 MetaGraphDef 通过 tag 进行区分,如:"training", "inference" 和 "mobile"),但多个计算图共享 variables and assets,内存使用效率更高。signature_def

在部署 Tensorflow 模型时,我们需要为模型指定输入输出张量的名称。如果之前没有协商好,这个需求逼着我们得在整张计算图中寻找相应的张量,非常繁琐。因此,SavedModel 提供了 SignatureDefs,简化了这一过程。SignatureDefs 定义了一组 TensorFlow 支持的计算签名,便于在计算图中找到适合的输入输出张量。简单的说,使用这些计算签名,可以准确指定特定的输入输出节点。

2.6.2. 存储结构

模型的存储结构如下:

└── model

···├── saved_model.pb/pbtxt

···└── variables/

·········├── variables.data-*****-of-*****

·········└── variables.index

···├── assets/

···├── assets.extra/saved_model.pb/pbtxt:

//tensorflow/core/protobuf/saved_model.proto

message SavedModel {

int64 saved_model_schema_version = 1;

repeated MetaGraphDef meta_graphs = 2;

}

从以上定义可以看出,`saved_model.pb` 实际上不是标准的 `MetaGraphDef`,而是一个更高级的封装,可以包含多个 `MetaGraphDef`。variables 目录(.index 文件 + .data): 保存训练所得的权重。

P.S. 如果从 FrozenGraphDef 的 pb 文件中转出来的模型,variables 目录为空,因为 pb 文件里面的各项参数都是 tf.constant,所以不会存储到 variable 里面。assets 目录:可能需要的外部文件。

assets.extra 目录:可以添加其特定 assets。

2.6.3. 导出:tf.saved_model.builder.SavedModelBuilder

对于 SavedModel 的导出,Tensorflow 提供了 tf.saved_model.builder.SavedModelBuilder 对象用于相关操作:

# python/saved_model/builder_impl.py

@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])

class SavedModelBuilder(_SavedModelBuilder):

e.g.

...

builder = tf.saved_model.Builder(export_dir)

with tf.Session(graph=tf.Graph()) as sess:

...

builder.add_meta_graph_and_variables(sess, ["foo-tag"],

signature_def_map=foo_signatures, assets_list=foo_assets)

...

with tf.Session(graph=tf.Graph()) as sess:

...

builder.add_meta_graph(["bar-tag", "baz-tag"])

...

builder.save()

同样的,具体的使用方法可以参考相关文档,这里不再赘述了。我们关注到:SavedModelBuilder 的模型持久化流程实际上都是对 tf.train.Saver() 的调用封装:add_meta_graph_and_variables 方法中:保存 variables 调用的是 saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

保存 MetaGraphDef 调用的是 saver.export_meta_graph(clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)。

add_meta_graph 方法中:保存 MetaGraphDef 调用的是 saver.export_meta_graph(clear_devices=clear_devices, strip_default_attrs=True)。

2.6.4. 导入:tf.saved_model.loader.load

Tensorflow 在 tf.saved_model 模块下同样提供了 SavedModel 导入的 API:

# python/saved_model/loader_impl.py

@tf_export(v1=["saved_model.load", "saved_model.loader.load"])

def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):

e.g.

import tensorflow as tf

with tf.Session() as sess:

tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], "/data/mnist/saved-model")

print(sess.run('layer2/biases/Variable:0'))

从源码中我们可以看到,tf.saved_model.loader.load 方法的执行流程主要包含 2 步:

1. 根据指定的 tag 在多个 meta_graph 中找到对应的 meta_graph,进行加载并创建 saver 对象:

# tensorflow/python/saved_model/loader_impl.py

for meta_graph_def in self._saved_model.meta_graphs:

if set(meta_graph_def.meta_info_def.tags) == set(tags):

meta_graph_def_to_load = meta_graph_def

found_match = True

break

tf_saver._import_meta_graph_with_return_elements(

meta_graph_def, import_scope=import_scope, **saver_kwargs)

2. restore variables 权值到当前 session 中:

with sess.graph.as_default():

if (saver is None and

not variables._all_saveable_objects(scope=import_scope)):

tf_logging.info("The specified SavedModel has no variables; no "

"checkpoints were restored.")

elif isinstance(saver, tf_saver.Saver):

saver.restore(sess, self._variables_path)

可以看到,这里的主要逻辑同样是对 `saver` 类中的方法的调用封装。

NOTE:

不同于 tf.train.Saver() 的 restore() 方法,tf.saved_model.loader.load 既会恢复 tag 指定的图结构 (MetaGraphDef),也会恢复参数变量的训练权值。

3. 模型静态分析

理想情况下,模型发布者会编写出完备的文档,给出示例代码,供其它人参考使用。但在很多实际情况下,我们只是拿到了训练好的模型,没有相关的代码更没有齐全的文档,这个时候我们能否对手头的模型文件进行一定分析,从中获得到一些有用的信息呢?

所谓“静态分析”,指的就是在模型还未能运行起来的情况下,从模型的结构、variables 等方面切入获取相关的信息。以下简要列举了几种对模型进行静态分析的方式。

P.S. 对于 Tensorflow 模型的动态分析(性能分析 / profiling),可以参考我们团队同学写的另一篇文章。

3.1. 直接打印

对于 Tensorflow 的持久化模型,图的基本结构都使用 pb 进行表示,将图的静态结构进行持久化实际上核心逻辑就是 pb 的序列化。因此,我们可以直接打印 pb 文件中的节点信息,来获取模型的静态结构定义。

3.1.1. 二进制格式

对于二进制的模型文件 (e.g. *.pb / *.meta),可读取文件后通过 protobuf 类对象的 ParseFromString 方法加载为 protobuf instance 后进行打印:

import tensorflow as tf

from tensorflow.python.platform import gfile

filename = 'xxx.pb' # 'xxx.meta'

with gfile.FastGFile(filename,'rb') as f:

graph_def = tf.GraphDef() # 对于 .meta 文件,使用 'tf.MetaGraphDef()'

graph_def.ParseFromString(f.read())

for n in graph_def.node:

print("Name of the node -%s" % n.name)

其中,对于不同的 pb 类型,要使用相应的 protobuf 类对象:`GraphDef` 类型:使用 `tf.GraphDef()` (Export protos in `tensorflow/python/__init__.py`)

`MetaGraphDef` 类型:使用 `tf.MetaGraphDef()` (Export protos in `tensorflow/python/__init__.py`)

其它 Tensorflow Proto 类型:

# 以 SavedModel 为例:

import tensorflow as tf

from tensorflow.python.platform import gfile

from tensorflow.python.util import compat

# tensorflow/core/protobuf 目录下的 xxx.proto 文件在编译后都会生成相应的 xxx_pb2.py

from tensorflow.core.protobuf import saved_model_pb2

filename = 'saved_model.pb' # 'xxx.meta'

with gfile.FastGFile(filename,'rb') as f:

graph_def = saved_model_pb2.SavedModel()

graph_def.ParseFromString(compat.as_bytes(f.read()))

for n in graph_def.meta_graphs:

print(n)

3.1.2. 文本格式

对于二进制的模型文件 (e.g. *.pbtxt / *.json):

import tensorflow as tf

from tensorflow.python.platform import gfile

from google.protobuf import text_format

filename = 'xxx.pbtxt'

with tf.gfile.FastGFile(filename, 'r') as f:

graph_def = tf.GraphDef()

text_format.Merge(f.read(), graph_def)

for n in graph_def.node:

print("Name of the node -%s" % n.name)

3.2. Tensorflow Tools

在 Tensorflow 源代码的 tensorflow/python/tools 目录下,实际上也提供了可以对模型文件进行快速解析的工具脚本。

3.2.1. saved_model_cli

对于 SavedModel 格式的持久化模型,Tensorflow 提供了可用于静态分析的脚本工具 saved_model_cli (tensorflow/python/tools/saved_model_cli.py)。

saved_model_cli 支持使用多个子命令:show / run / scan / convert,我们这里主要关注其中的 show 命令,通过该命令可以快速输出 SavedModel 中的 meta_graph 信息:

e.g. 输出 SavedModel 中的所有 meta_graph 信息:

# 打印 SavedModel 中的可用信息

$ python $TENSORFLOW_DIR/tensorflow/python/tools/saved_model_cli.py show --dir /data/saved_model --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['user_signature']:

The given SavedModel SignatureDef contains the following input(s):

inputs['inputs'] tensor_info:

dtype: DT_STRING

shape: unknown_rank

name: tf_example:0

The given SavedModel SignatureDef contains the following output(s):

outputs['scores'] tensor_info:

dtype: DT_FLOAT

shape: (-1, 32)

name: truediv:0

Method name is: tensorflow/serving/classify

3.2.2. inspect_checkpoint

对于 checkpoint 格式保存的持久化模型,Tensorflow 同样提供了可用于静态分析的脚本工具 inspect_checkpoint (python/tools/inspect_checkpoint.py)。

e.g.

$ ls

checkpoint model.ckpt-0.data-00000-of-00001 model.ckpt-0.index model.ckpt-0.meta

$ python $TENSORFLOW_DIR/tensorflow/python/tools/inspect_checkpoint.py --file_name=model.ckpt-0

wuid_redident_location_province/1/Adam_1 (DT_FLOAT) [2000,10]

wuid_redident_location_province/2 (DT_FLOAT) [2000,10]

wuid_redident_location_province/2/Adam (DT_FLOAT) [2000,10]

...

3.3. Tensorboard

通过 Tensorboard 可以很直观的在 GUI 上观察模型的节点信息,只要拿到的模型文件中包含 GraphDef,就可以 import_graph_def 后生成 summary info,并在 Tensorboard 上加载后进行分析。

e.g.

import sys

import tensorflow as tf

from tensorflow.core.protobuf import saved_model_pb2

from tensorflow.python.platform import gfile

from tensorflow.python.util import compat

with tf.Session() as sess:

model_filename ='./model/saved_model.pb'

with gfile.FastGFile(model_filename, 'rb') as f:

data = compat.as_bytes(f.read())

sm = saved_model_pb2.SavedModel()

sm.ParseFromString(data)

assert 1 == len(sm.meta_graphs),

'More than one graph found. Not sure which to write!'

g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)

LOGDIR='./log'

train_writer = tf.summary.FileWriter(LOGDIR)

train_writer.add_graph(sess.graph)

train_writer.flush()

train_writer.close()

启动 TensorBoard:

$ tensorboard --logdir ./log

4. Refer Links

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值