5.1 Tensorflow:图与模型的加载与存储

原创 2017年08月04日 12:12:51

前言

自己学Tensorflow,现在看的书是《TensorFlow技术解析与实战》,不得不说这书前面的部分有点坑,后面的还不清楚.图与模型的加载写的不清楚,书上的代码还不能运行=- =,真是BI….咳咳.之后还是开始了查文档,翻博客的填坑之旅
,以下为学习总结.

快速应用

存储与加载,简单示例

# 一般而言我们是构建模型之后,session运行,但是这次不同之处在于我们是构件好之后存储了模型
# 然后在session中加载存储好的模型,再运行
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 3]), name='v2')
init_op = tf.global_variables_initializer() # 初始化全部变量
# saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()
# 只存储图
if not os.path.exists('save/model.meta'):
    saver.export_meta_graph('save/model.meta')


print()
with tf.Session() as sess:
    sess.run(init_op)
    print('v1:', sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print('v2:', sess.run(v2))
    saver_path = saver.save(sess, 'save/model.ckpt')  # 将模型保存到save/model.ckpt文件
    print('Model saved in file:', saver_path)

print()
with tf.Session() as sess:
    saver.restore(sess, 'save/model.ckpt') # 即将固化到硬盘中的模型从保存路径再读取出来,这样就可以直接使用之前训练好,或者训练到某一阶段的的模型了
    print('v1:', sess.run(v1)) # 打印v1、v2的值和之前的进行对比
    print('v2:', sess.run(v2))
    print('Model Restored')

print()
# 只加载图,
saver = tf.train.import_meta_graph('save/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, 'save/model.ckpt')
    # 通过张量的名称来获取张量,也可以直接运行新的张量
    print('v1:', sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
    print('v2:', sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))

运行结果:


v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model saved in file: save/model.ckpt

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model Restored

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]

构建模型后直接运行的结果,与加载存储的模型,加载存储的图,并哪找张量的名称获取张量并运行的结果是一致的

存储的文件

保存的文件

tf.train.Saver与存储文件的讲解

核心定义

主要类:tf.train.Saver类负责保存和还原神经网络
自动保存为三个文件:模型文件列表checkpoint,计算图结构model.ckpt.meta,每个变量的取值model.ckpt。其中前两个自动生成。
加载持久化图:通过tf.train.import_meta_graph(“save/model.ckpt.meta”)加载持久化的图

存储文件的讲解

这段代码中,通过saver.save函数将TensorFlow模型保存到了model/model.ckpt文件中,这里代码中指定路径为”save/model.ckpt”,也就是保存到了当前程序所在文件夹里面的save文件夹中。

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在
checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState
Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef
Protocol Buffer定义的。MetaGraphDef
中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef
信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice
Protocol
Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,请自查。

保存图与模型进阶

按迭代次数保存

# 在1000次迭代时存储
saver.save(sess, 'my_test_model',global_step=1000)

运行结果:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint

按时间保存

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

更详细的解释

其实更详细的解释就在源码之中,这些英语还是简单,我相信以大家的水平应该都能看得懂。就不侮辱大家的智商。

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               # 默认时间是一万小时,有趣
               # 但我们只争朝夕
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type.  This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint
        where the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep.
        Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device.  This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate
        a `Saver` object for a previously built `Graph` that had a `Saver`.
        The `saver_def` proto should be the one returned by the
        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BaseSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no
        variables in the graph. Otherwise, construct the saver anyway and make
        it a no-op.
      write_version: controls what format to use when saving checkpoints.  It
        also affects certain filepath matching logic.  The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore.  Regardless of
        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default).  This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
    """
版权声明:欢迎转载,共同学习,但请尊重版权,标明出处:http://blog.csdn.net/fontthrone

Tensorflow学习笔记-模型保存与加载

使用Tensorflow训练好模型之后,我们需要将训练好的模型保存起来,方便以后的使用,这就是Tensorflow模型的持久化。保存v1 = tf.Variable(tf.constant(1,sha...
  • lovelyaiq
  • lovelyaiq
  • 2017年11月27日 16:51
  • 621

TensorFlow学习笔记(8)--网络模型的保存和读取

之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用...
  • lwplwf
  • lwplwf
  • 2017年03月16日 11:23
  • 28152

tensorflow学习之识别单张图片的实现(python手写数字)

假设我们已经安装好了tensorflow。 一般在安装好tensorflow后,都会跑它的demo,而最常见的demo就是手写数字识别的demo,也就是mnist数据集。 然而我们仅仅是跑了它的dem...
  • gaohuazhao
  • gaohuazhao
  • 2017年06月06日 19:28
  • 7926

tensorflow问题集合

1、问题:为什么模型文件是三个(.data-00000-of-00001和.index和.meta)而没有.ckpt后缀的文件? 解答:引文新版本的saver会保存成三个后缀的形式,而旧版本的sav...
  • jessir
  • jessir
  • 2017年08月09日 16:12
  • 212

Tensorflow系列——Saver的用法

Saver的用法 1. Saver的背景介绍     我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了S...
  • u011500062
  • u011500062
  • 2016年06月21日 17:17
  • 35263

5.2 TensorFlow:模型的加载,存储,实例

背景之前已经写过TensorFlow图与模型的加载与存储了,写的很详细,但是或闻有人没看懂,所以在附上一个关于模型加载与存储的例子,.其中模型很巧妙,比之前numpy写一大堆简单多了,这样有利于把主要...
  • FontThrone
  • FontThrone
  • 2017年08月12日 13:06
  • 2873

查看tensorflow ckpt文件中的变量名和对应值

查看tf ckpt文件中的变量名和对应值
  • u010698086
  • u010698086
  • 2017年09月09日 17:22
  • 1463

tensorflow学习笔记(十):sess.run()

session.run()session.run([fetch1, fetch2])import tensorflow as tf state = tf.Variable(0.0,dtype=tf.f...
  • u012436149
  • u012436149
  • 2016年10月24日 09:04
  • 29927

TensorFlow-sess.run()

当我们构建完图后,需要在一个会话中启动图,启动的第一步是创建一个Session对象。 为了取回(Fetch)操作的输出内容, 可以在使用 Session 对象的 run()调用执行图时,传入一些 ...
  • laolu1573
  • laolu1573
  • 2017年03月28日 17:02
  • 2719

tensorflow学习笔记(十):sess.run()

session.run() 【2016.12.28.错误更新:之前对sess.run([train_op, loss])理解有误,已更新成正确版本】 session.run([fetch1, fe...
  • oHongHong
  • oHongHong
  • 2017年05月27日 09:40
  • 2676
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:5.1 Tensorflow:图与模型的加载与存储
举报原因:
原因补充:

(最多只允许输入30个字)