
5.1 TensorFlow: Saving and Loading Graphs and Models


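Preface

I'm teaching myself TensorFlow, currently from the book 《TensorFlow技术解析与实战》. I have to say the early chapters are a bit of a trap (I can't speak for the later ones yet). The part on loading graphs and models is unclear, and the book's code doesn't even run =-=, which is really... ahem. So I went off reading the docs and digging through blog posts to fill in the gaps. What follows is a summary of what I learned.

Quick start

Saving and loading: a simple example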

# Normally we build a model and then run it in a session. Here, after building
# the model we save it, then load the saved model in a session and run it.
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Declare two variables
v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 3]), name='v2')
init_op = tf.global_variables_initializer()  # op that initializes all variables
# saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)  # create a tf.train.Saver to save the model (V1 format)
saver = tf.train.Saver()
# Save only the graph (no variable values)
if not os.path.exists('save/model.meta'):
    saver.export_meta_graph('save/model.meta')


print()
with tf.Session() as sess:
    sess.run(init_op)
    print('v1:', sess.run(v1))  # print v1 and v2 so we can compare after restoring
    print('v2:', sess.run(v2))
    saver_path = saver.save(sess, 'save/model.ckpt')  # save the model under the prefix save/model.ckpt
    print('Model saved in file:', saver_path)

print()
with tf.Session() as sess:
    # Read the model persisted on disk back from its save path, so a model that
    # was fully trained, or trained up to some stage, can be used directly.
    saver.restore(sess, 'save/model.ckpt')
    print('v1:', sess.run(v1))  # print v1 and v2 to compare with the values above
    print('v2:', sess.run(v2))
    print('Model Restored')

print()
# Load only the graph from the .meta file, then restore the variable values.
# Reset the default graph first so the imported nodes keep their original
# names (v1, v2) instead of colliding with the ones defined above.
tf.reset_default_graph()
saver = tf.train.import_meta_graph('save/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, 'save/model.ckpt')
    # Fetch tensors by name; new tensors built on top of them can be run as well
    print('v1:', sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
    print('v2:', sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))

Output:


v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model saved in file: save/model.ckpt

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model Restored

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]

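The values obtained by running the model right after building it are identical to those obtained by restoring the saved model, and to those obtained by loading the saved graph and fetching the tensors by name.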

Saved files

[Figure: the saved files]

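tf.train.Saver and the saved files explained

Core definitions

Main class: the tf.train.Saver class is responsible for saving and restoring a neural network.
Saved files: a save writes the list of model files (checkpoint), the graph structure (model.ckpt.meta), and the value of every variable. The V1 format puts the values in a single model.ckpt file; the default V2 format splits them into model.ckpt.index and model.ckpt.data-00000-of-00001. The checkpoint and .meta files are generated automatically.
Loading a persisted graph: call tf.train.import_meta_graph("save/model.ckpt.meta") to load the persisted graph.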

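The saved files explained

In the code above, saver.save writes the TensorFlow model under the path "save/model.ckpt", that is, into the save folder relative to the directory the program is run from.

By convention the save path is given a .ckpt suffix. After saving, several files appear in the save folder rather than one, because TensorFlow stores the structure of the graph separately from the parameter values on the graph. The book's description below assumes the older V1 format, which produces three files.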

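The checkpoint file holds the list of all model files in a directory. It is generated and maintained automatically by the tf.train.Saver class, and records the file names of every TensorFlow model persisted by that Saver. When the Saver removes an old model file (for example because of max_to_keep), the corresponding name is also dropped from the checkpoint file. The contents of checkpoint follow the CheckpointState Protocol Buffer format.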
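Since the checkpoint file is stored as plain text, its contents are easy to inspect. The following is a minimal sketch in plain Python that pulls the latest and the full list of model paths out of such text. The sample contents and the helper name parse_checkpoint_state are made up for illustration; only the two field names come from CheckpointState.

```python
import re

# Sample contents of a `checkpoint` file after two hypothetical saves
# (the model names and step numbers are assumed for illustration).
SAMPLE_CHECKPOINT = '''model_checkpoint_path: "my_test_model-2000"
all_model_checkpoint_paths: "my_test_model-1000"
all_model_checkpoint_paths: "my_test_model-2000"
'''

# Matches lines of the form:  field_name: "value"
_FIELD = re.compile(r'^(\w+):\s*"(.*)"\s*$')


def parse_checkpoint_state(text):
    """Return (latest_path, all_paths) parsed from checkpoint-file text."""
    latest = None
    all_paths = []
    for line in text.splitlines():
        match = _FIELD.match(line.strip())
        if match is None:
            continue  # skip blank lines and fields this sketch ignores
        key, value = match.groups()
        if key == 'model_checkpoint_path':
            latest = value
        elif key == 'all_model_checkpoint_paths':
            all_paths.append(value)
    return latest, all_paths


latest, all_paths = parse_checkpoint_state(SAMPLE_CHECKPOINT)
print('latest:', latest)
print('all:', all_paths)
```

In real code there is no need to parse the file by hand: tf.train.latest_checkpoint('save') returns the most recent path, and tf.train.get_checkpoint_state('save') returns the whole CheckpointState.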

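The model.ckpt.meta file stores the structure of the TensorFlow graph, which can be thought of as the network architecture of the neural network. TensorFlow uses a MetaGraph to record information about the nodes in the graph, along with the metadata needed to run them. The MetaGraph is defined by the MetaGraphDef Protocol Buffer, and the contents of MetaGraphDef make up the first file TensorFlow writes when persisting a model. Files holding MetaGraphDef data use the .meta suffix by default; model.ckpt.meta stores exactly this MetaGraph data.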
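To see that a .meta file really carries only structure, the sketch below writes a tiny graph to a temporary .meta file, imports it into a fresh empty graph, and lists the node names. It assumes TensorFlow 1.13 or later (including 2.x), where the 1.x-style API is available as tf.compat.v1; the temporary path and the variable name v1 are chosen for illustration.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # 1.x-style API; assumed available (TensorFlow >= 1.13)

meta_path = os.path.join(tempfile.mkdtemp(), 'model.meta')

# Build a small graph and write only its structure to disk.
with tf.Graph().as_default():
    tf1.get_variable('v1', shape=[1, 2])
    tf1.train.Saver().export_meta_graph(meta_path)

# Import the structure into a fresh, empty graph and inspect its nodes.
with tf.Graph().as_default() as graph:
    tf1.train.import_meta_graph(meta_path)
    op_names = [op.name for op in graph.get_operations()]

print(len(op_names), 'ops imported')
print('variable node present:', 'v1' in op_names)
```

No variable values are involved at any point here: no session is created and no checkpoint is read, so only the node definitions travel through the .meta file.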

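In the V1 format, the model.ckpt file stores the value of every variable in the TensorFlow program. It is written in SSTable format, which can be loosely understood as a list of (key, value) pairs. The first row of the list describes the file's metadata, such as the list of variables stored in the file. Each remaining row holds one slice of a variable; slice information is defined by the SavedSlice Protocol Buffer, which records the variable's name, information about the current slice, and the variable's values. In the default V2 format the same data is split across model.ckpt.index and model.ckpt.data-00000-of-00001. TensorFlow provides the tf.train.NewCheckpointReader class for inspecting the variables stored in a checkpoint; see its documentation for the full details.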
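As a starting point for tf.train.NewCheckpointReader, here is a minimal sketch that saves two variables to a temporary directory and then reads their names, shapes, and values back without rebuilding the graph. It assumes TensorFlow 1.13 or later (including 2.x) and uses the tf.compat.v1 alias; the temporary path and the constant initial values are chosen for illustration.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # 1.x-style API; assumed available (TensorFlow >= 1.13)

prefix = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

# Save two variables with known values so the result is easy to check.
with tf.Graph().as_default():
    tf1.get_variable('v1', shape=[1, 2], initializer=tf1.zeros_initializer())
    tf1.get_variable('v2', shape=[2, 3], initializer=tf1.ones_initializer())
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, prefix)

# Read the checkpoint directly; no graph or session is needed for this.
reader = tf1.train.NewCheckpointReader(prefix)
shape_map = reader.get_variable_to_shape_map()  # dict: variable name -> shape
for name in sorted(shape_map):
    print(name, shape_map[name])
    print(reader.get_tensor(name))  # the stored values as a NumPy array
```

The reader takes the same path prefix that was passed to saver.save, not the name of any single file on disk.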

Saving graphs and models: advanced usage

Saving by iteration count

# Save at iteration 1000
saver.save(sess, 'my_test_model', global_step=1000)

Resulting files:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
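The file names above follow a simple pattern: the step number is appended to the save path with a hyphen, and the per-checkpoint files share that prefix. The sketch below reproduces the pattern in plain Python. The helpers checkpoint_prefix and expected_files are hypothetical names, not TensorFlow functions, and the .data suffix assumes a single shard, as in the listing above.

```python
def checkpoint_prefix(save_path, global_step=None, pad_step_number=False):
    """Mimic how the step number is appended to the save path (illustrative)."""
    if global_step is None:
        return save_path
    # With pad_step_number the step is zero-padded to a width of 8.
    step = '{:08d}'.format(global_step) if pad_step_number else str(global_step)
    return '{}-{}'.format(save_path, step)


def expected_files(prefix):
    """File names for one V2 checkpoint, assuming a single data shard."""
    return [prefix + '.index', prefix + '.meta', prefix + '.data-00000-of-00001']


prefix = checkpoint_prefix('my_test_model', global_step=1000)
for file_name in expected_files(prefix):
    print(file_name)
print(checkpoint_prefix('my_test_model', global_step=1000, pad_step_number=True))
```

When restoring, it is this prefix (my_test_model-1000), not any individual file name, that gets passed to saver.restore.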

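Keeping checkpoints by time

# Keep the 4 most recent checkpoints and, in addition, one checkpoint for every 2 hours of training.
# Note that this only controls which checkpoints are kept; saver.save still has to be called explicitly.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)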
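How the two arguments interact is easier to see with a small simulation. The sketch below is a simplified model in plain Python, based on my reading of the Saver's behaviour rather than on the actual TensorFlow code: the real Saver works with wall-clock timestamps, while here times are given as hours since the Saver was created. The function name simulate_retention and the save schedule are assumptions for illustration.

```python
def simulate_retention(save_times_hours, max_to_keep, keep_every_n_hours):
    """Return (recent, kept_forever): save times in hours since Saver creation."""
    recent = []        # rolling window of the most recent checkpoints
    kept_forever = []  # checkpoints preserved by the every-N-hours rule
    next_keep_time = keep_every_n_hours
    for t in save_times_hours:
        recent.append(t)
        if len(recent) > max_to_keep:
            oldest = recent.pop(0)
            if oldest > next_keep_time:
                # Old enough to count as the keeper for this N-hour interval.
                kept_forever.append(oldest)
                next_keep_time += keep_every_n_hours
            # Otherwise the oldest checkpoint would be deleted from disk.
    return recent, kept_forever


# Hypothetical schedule: one save every 30 minutes for 6 hours.
save_times = [0.5 * i for i in range(1, 13)]
recent, kept = simulate_retention(save_times, max_to_keep=4, keep_every_n_hours=2)
print('recent window:', recent)
print('kept permanently:', kept)
```

Under this model, a checkpoint only becomes a permanent keeper at the moment it falls out of the recent window, which is why the kept list lags behind the training time.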

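More detailed explanation

The most detailed explanation is in fact in the source code itself. The English is simple enough that I'm sure everyone can follow it, so I won't insult your intelligence by paraphrasing it.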

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               # The default is ten thousand hours, amusingly,
               # but we seize the day, seize the hour
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type.  This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint
        where the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep.
        Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device.  This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate
        a `Saver` object for a previously built `Graph` that had a `Saver`.
        The `saver_def` proto should be the one returned by the
        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BaseSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no
        variables in the graph. Otherwise, construct the saver anyway and make
        it a no-op.
      write_version: controls what format to use when saving checkpoints.  It
        also affects certain filepath matching logic.  The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore.  Regardless of
        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default).  This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
    """