转载自https://www.jarvis73.cn/2018/04/25/Tensorflow-Model-Save-Read/
本文假设读者已经懂得了 Tensorflow 的一些基础概念, 如果不懂, 则移步 TF 官网 .
在 Tensorflow 中我们一般使用 tf.train.Saver()
定义的存储器对象来保存模型, 并得到形如下面列表的文件:
checkpoint
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
其中 checkpoint
文件中记录了该储存器历史上所有保存过的模型(三件套文件)的名称, 以及最近一次保存的文件, 这里我们并不需要 checkpoint
.
Tensorflow 模型冻结是指把计算图的定义和模型权重合并到同一个文件中, 可以按照以下步骤实施:
- 恢复已保存的计算图: 把预先保存的计算图(meta graph) 载入到默认的计算图中, 并将计算图序列化.
- 加载权重: 开启一个会话(Session), 把权重载入到计算图中
- 删除推导所需以外的计算图元数据(metadata): 冻结模型之后是不需要训练的, 所以只保留推导(inference) 部分的计算图 (这部分可以通过指定模型输出来自动完成)
- 保存到硬盘: 序列化冻结的 graph_def 协议缓冲区(Protobuf) 并转储到硬盘
注意: 前两步实际上就是 Tensorflow 中的加载计算图和权重, 关键的部分就是图的冻结, 而冻结 TF 已经提供了函数.
1. 模型的保存
TF 使用 saver = tf.train.Saver()
定义一个存储器对象, 然后使用 saver.save()
函数保存模型. saver
定义时可以指定需要保存的变量列表, 最大的检查点数量, 是否保存计算图等. 官网例子如下:
v1 = tf.Variable(..., name='v1') v2 = tf.Variable(..., name='v2') # 使用字典指定要保存的变量, 此时可以为每个变量重命名(保存的名字) saver = tf.train.Saver({
'v1': v1, 'v2': v2}) # 使用列表指定要保存的变量, 变量名字不变. 以下两种保存方式等价 saver = tf.train.Saver([v1, v2]) saver = tf.train.Saver({
v.op.name: v for v in [v1, v2]}) # 保存相应变量到指定文件, 如果指定 global_step, 则实际保存的名称变为 model.ckpt-xxxx saver.save(sess, "./model.ckpt", global_step)
每保存一次, 就会产生前言所述的四个文件, 其中 checkpoint 文件会更新. 其中 saver.save()
函数的 write_meta_graph
参数默认为 True
, 即保存权重时同时保存计算图到 meta
文件.
2. 模型的读取
TF 模型的读取分为两种, 一种是我们仅读取模型变量, 即 index
文件和 data
文件; 另一种是读取计算图. 通常来说如果是我们自己保存的模型, 那么完全可以设置 saver.save()
函数的 write_meta_graph
参数为 False
以节省空间和保存的时间, 因为我们可以使用已有的代码直接重新构建计算图. 当然如果为了模型迁移到其他地方, 则最好同时保存变量和计算图.
2.1 读取计算图
2.1.1 读取计算图核心函数
从 meta
文件读取计算图使用 tf.train.import_meta_graph()
函数, 比如:
with tf.Session() as sess: new_saver = tf.train.import_meta_graph("model.ckpt.meta")
此时计算图就会加载到 sess
的默认计算图中, 这样我们就无需再次使用大量的脚本来定义计算图了. 实际上使用上面这两行代码即可完成计算图的读取. 注意可能我们获取的模型(meta文件)同时包含定义在CPU主机(host)和GPU等设