在TensorFlow中模型的保存和调用,相信大家都不会陌生,使用关键语句saver = tf.train.Saver()和saver.save()就可以完成。
但是,不知道大家是否了解,tensorflow通过checkpoint这一种格式文件,是将模型的结构和权重数据分开保存的,这就造成了一些使用场景下的不方便。
所以,我们需要一种方式将模型结构和权重数据合并在一个文件中,tensorflow提供了freeze_graph函数和pb文件格式,来解决这一问题。
这些模型文件是做什么的
在save之后,模型会保存在ckpt文件中,checkpoint文件保存了一个目录下所有的模型文件列表,events文件是给可视化工具tensorboard用的。
和保存的模型直接相关的是以下这三个文件:
- .data文件保存了当前参数值
- .index文件保存了当前参数名
- .meta文件保存了当前图结构
当你使用saver.restore()载入模型时,你用的就是这一组的三个checkpoint文件。
但是,当我们需要将模型和权重整合成一个文件时,我们就需要以下的操作了。
如何使用freeze_graph生成PB文件
tensorflow提供了freeze_graph这个函数来生成pb文件。以下的代码块可以完成将checkpoint文件转换成pb文件的操作:
- 载入你的模型结构,
- 提供checkpoint文件地址
- 使用tf.train.writegraph保存图,这个图会提供给freeze_graph使用
- 使用freeze_graph生成pb文件
import
在以上的程序运行之后,./pb_model/文件夹中就会出现frozen_model.pb文件,这是我们可以使用的模型结构和权重整合过的pb文件。
freeze_graph总共有11个参数,以下逐一介绍下,供大家参考:
- input_graph:模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分。我们的例子中,使用了二进制的pb文件,对应input_binary就是False
- input_saver:Saver解析器,主要用于版本不兼容时使用。通常为空,为空时用当前版本的Saver
- input_binary:配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认值是False
- input_checkpoint:checkpoint文件地址
- output_node_names:输出节点的名字,有多个时用逗号分开,我们的输出节点是'out',这是我们使用flow = tf.cast(flow, tf.int8, 'out')将模型的输出节点命名为out。如果没有这一步的操作,我们可以找到模型的输出节点名是什么,并且在这一参数中对应。
- restore_op_name:从模型恢复节点的名字,一般使用默认:save/restore_all
- filename_tensor_name:一般使用默认:save/Const:0
- output_graph:用来保存整合后的模型输出文件,即pb文件的保存地址
- clear_devices:指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认),默认True
- initializer_nodes:默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
- variable_names_blacklist:默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。