5.1 Tensorflow:图与模型的加载与存储

原创 2017年08月04日 12:12:51

前言

自己学Tensorflow,现在看的书是《TensorFlow技术解析与实战》,不得不说这书前面的部分有点坑,后面的还不清楚.图与模型的加载写的不清楚,书上的代码还不能运行=- =,真是BI….咳咳.之后还是开始了查文档,翻博客的填坑之旅
,以下为学习总结.

快速应用

存储与加载,简单示例

# 一般而言我们是构建模型之后,session运行,但是这次不同之处在于我们是构件好之后存储了模型
# 然后在session中加载存储好的模型,再运行
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 3]), name='v2')
init_op = tf.global_variables_initializer() # 初始化全部变量
# saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()
# 只存储图
if not os.path.exists('save/model.meta'):
    saver.export_meta_graph('save/model.meta')


print()
with tf.Session() as sess:
    sess.run(init_op)
    print('v1:', sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print('v2:', sess.run(v2))
    saver_path = saver.save(sess, 'save/model.ckpt')  # 将模型保存到save/model.ckpt文件
    print('Model saved in file:', saver_path)

print()
with tf.Session() as sess:
    saver.restore(sess, 'save/model.ckpt') # 即将固化到硬盘中的模型从保存路径再读取出来,这样就可以直接使用之前训练好,或者训练到某一阶段的的模型了
    print('v1:', sess.run(v1)) # 打印v1、v2的值和之前的进行对比
    print('v2:', sess.run(v2))
    print('Model Restored')

print()
# 只加载图,
saver = tf.train.import_meta_graph('save/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, 'save/model.ckpt')
    # 通过张量的名称来获取张量,也可以直接运行新的张量
    print('v1:', sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
    print('v2:', sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))

运行结果:


v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model saved in file: save/model.ckpt

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model Restored

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]

构建模型后直接运行的结果,与加载存储的模型,加载存储的图,并哪找张量的名称获取张量并运行的结果是一致的

存储的文件

保存的文件

tf.train.Saver与存储文件的讲解

核心定义

主要类:tf.train.Saver类负责保存和还原神经网络
自动保存为三个文件:模型文件列表checkpoint,计算图结构model.ckpt.meta,每个变量的取值model.ckpt。其中前两个自动生成。
加载持久化图:通过tf.train.import_meta_graph(“save/model.ckpt.meta”)加载持久化的图

存储文件的讲解

这段代码中,通过saver.save函数将TensorFlow模型保存到了model/model.ckpt文件中,这里代码中指定路径为”save/model.ckpt”,也就是保存到了当前程序所在文件夹里面的save文件夹中。

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在
checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState
Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef
Protocol Buffer定义的。MetaGraphDef
中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef
信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice
Protocol
Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,请自查。

保存图与模型进阶

按迭代次数保存

# 在1000次迭代时存储
saver.save(sess, 'my_test_model',global_step=1000)

运行结果:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint

按时间保存

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

更详细的解释

其实更详细的解释就在源码之中,这些英语还是简单,我相信以大家的水平应该都能看得懂。就不侮辱大家的智商。

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               # 默认时间是一万小时,有趣
               # 但我们只争朝夕
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type.  This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint
        where the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep.
        Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device.  This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate
        a `Saver` object for a previously built `Graph` that had a `Saver`.
        The `saver_def` proto should be the one returned by the
        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BaseSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no
        variables in the graph. Otherwise, construct the saver anyway and make
        it a no-op.
      write_version: controls what format to use when saving checkpoints.  It
        also affects certain filepath matching logic.  The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore.  Regardless of
        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default).  This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
    """
版权声明:转载请标明出处:http://blog.csdn.net/fontthrone

相关文章推荐

Ubuntu循环登录问题解决方案

求助!!ubuntu12.04管理员账户登录不了桌面,只能客人会话登录。 登录管理员账户时,输入密码后,一直在登录界面循环 费了好大劲啊,一上午的时间,终于搞定了,哈哈哈 ...

Python标准库之pickle包,cpickle包

1、pickle包 对于上述过程,最常用的工具是Python中的pickle包。 (1)、将内存中的对象转换成为文本流: import pickle # define class cl...

Tensorflow实战学习(四十九)【模型存储加载,队列线程,加载数据,自定义操作】

生成检查点文件(chekpoint file),扩展名.ckpt,tf.train.Saver对象调用Saver.save()生成。包含权重和其他程序定义变量,不包含图结构。另一程序使用,需要重新创建...
  • WuLex
  • WuLex
  • 2017-11-23 11:22
  • 77

tensorflow-模型保存和加载(一)

模型保存: import tensorflow as tf # save to file W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.floa...

tensorflow 模型保存与加载

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ 什么...
  • spylyt
  • spylyt
  • 2017-05-11 10:17
  • 4231

TensorFlow使用C++加载使用训练好的模型,.cc文件代码实现的相关类及方法总结

在官网API和Tensorflow源码头文件中查看获取。 同时参考 https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-t...

tensorflow从0开始(6)——保存加载模型

目的 学习tensorflow的目的是能够训练的模型,并且利用已经训练好的模型对新数据进行预测。下文就是一个简单的保存模型加载模型的过程。 保存模型 import tenso...

tensorflow学习笔记六:保存和加载训练模型

对于机器学习,尤其是深度学习DL的算法,模型训练可能很耗时,几个小时或者几天,所以如果是测试模块出了问题,每次都要重新运行就显得很浪费时间,所以如果训练部分没有问题,那么可以直接将训练的模型保存起来,...

tensorflow之inception_v3模型的部分加载及权重的部分恢复(23)---《深度学习》

大家都知道,在加载模型及对应的权重进行训练的时候,我们可以整个使用所提供的模型,但是有时候呢?所提供的模型不能很好的满足我们的要求,有时候我们只需要模型的前几层然后进行对应的权重赋值,这时候,我们应该...

Tensorflow学习笔记-模型保存与加载

使用Tensorflow训练好模型之后,我们需要将训练好的模型保存起来,方便以后的使用,这就是Tensorflow模型的持久化。保存v1 = tf.Variable(tf.constant(1,sha...
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:深度学习:神经网络中的前向传播和反向传播算法推导
举报原因:
原因补充:

(最多只允许输入30个字)