5.2 TensorFlow:模型的加载,存储,实例

原创 2017年08月12日 13:06:24

背景

之前已经写过TensorFlow图与模型的加载与存储了,写的很详细,但是或闻有人没看懂,所以在附上一个关于模型加载与存储的例子,CODE是我偶然看到了,就记下来了.其中模型很巧妙,比之前numpy写一大堆简单多了,这样有利于把主要注意力放在模型的加载与存储上.

解析

创建保存文件的类:saver = tf.train.Saver()

saver = tf.train.Saver() ,即为常见保存模型,图,数据的类,其内部结构在源码中有详细的解释,这个之前的文章已经说过了,这次只讲,我们如何我们具体要用的方法

saver.save() 保存

源码结构

 def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True):

# 实际运用 :
# saver = tf.train.Saver()
# saver.save(sess, checkpoint_dir + 'model55.ckpt', global_step=i+1)
# 注意,实际保存时 model55.ckpt 会被保存为多个文件

常用的参数:
1. sess : 要保存的session
2. save_path :保存路径,注意想要保存在代码所在目录下,前面不要加’/’不然会变成根目录
3. global_step :多次迭代时,使用该参数,按照步骤保存
4. 保存文件如下,后面的-50,100,是按照步骤(global_step)保存的
实际存储的文件

调用

源码结构

def restore(self, sess, save_path):

# sess 即为 当前session
# save_path : 与之前保存时的使用的名字一直
# 如果调取上一个例子存储的模型:此时 save_path = checkpoint_dir + 'model55.ckpt' 


# 代码实例 :saver.restore(sess, ckpt.model_checkpoint_path)
  1. saver.restore(),会恢复原来session 中的图,参数,等(也就是相当于直接调用原来训练好的模型),假如你传入的文件夹中存储着多个model.ckpt文件组,那么会默认调用最后存储的ckpt文件组,
  2. ckpt文件组的排序为:当按照步骤排序时,最后保存的步骤为最新,按照时间排序时,同理

ckpt文件

之前已经在原来的文章中写过,这里有必要再发一次

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在
checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState
Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef
Protocol Buffer定义的。MetaGraphDef
中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef
信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice
Protocol
Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,请自查。

CODE AND RUN

import tensorflow as tf
import numpy as np
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

x = tf.placeholder(tf.float32, shape=[None, 1])
# 拟合 y 
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = 'save/'

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        for i in range(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + 'model55.ckpt', global_step=i+1)
                print(sess.run(w))
                print(sess.run(b))
                '''
                运行结果
                [ 3.87540483]
                [ 4.07181311]
                最后训练好的模型跑出来的数据
                [ 3.994277]
                [ 4.00329876]
                '''
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass
        print(sess.run(w))
        print(sess.run(b))
        '''
        [ 3.994277]
        [ 4.00329876]
        '''

最后

更详细的内容,请点击这里

版权声明:转载请标明出处:http://blog.csdn.net/fontthrone,也请保留该信息

Tensorflow系列——Saver的用法

Saver的用法 1. Saver的背景介绍     我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了S...
  • u011500062
  • u011500062
  • 2016年06月21日 17:17
  • 32804

5.1 Tensorflow:图与模型的加载与存储

前言自己学Tensorflow,现在看的书是《TensorFlow技术解析与实战》,不得不说这书前面的部分有点坑,后面的还不清楚.图与模型的加载写的不清楚,书上的代码还不能运行=- =,真是BI….咳...
  • FontThrone
  • FontThrone
  • 2017年08月04日 12:12
  • 1203

查看tensorflow ckpt文件中的变量名和对应值

查看tf ckpt文件中的变量名和对应值
  • u010698086
  • u010698086
  • 2017年09月09日 17:22
  • 1193

tensorflow学习笔记(十):sess.run()

session.run()session.run([fetch1, fetch2])import tensorflow as tf state = tf.Variable(0.0,dtype=tf.f...
  • u012436149
  • u012436149
  • 2016年10月24日 09:04
  • 27495

TensorFlow-sess.run()

当我们构建完图后,需要在一个会话中启动图,启动的第一步是创建一个Session对象。 为了取回(Fetch)操作的输出内容, 可以在使用 Session 对象的 run()调用执行图时,传入一些 ...
  • laolu1573
  • laolu1573
  • 2017年03月28日 17:02
  • 2231

tensorflow学习笔记(十):sess.run()

session.run() 【2016.12.28.错误更新:之前对sess.run([train_op, loss])理解有误,已更新成正确版本】 session.run([fetch1, fe...
  • oHongHong
  • oHongHong
  • 2017年05月27日 09:40
  • 2336

session.run()是非常耗时的,千万不要用session.run的方式去取数据

1、修改某一程序的时候,使用了session去取数据,导致时间效率非常低。后来,对session.run()进行了测试,发现使用session读取数据的效率是非常低下的. # -*- coding:...
  • lujiandong1
  • lujiandong1
  • 2016年12月06日 15:38
  • 5160

TensorFlow 学习(二)—— tf.Session() 与 tf.Session().run()

1. 使用 tf.Session().run() 读取变量的值十分耗时
  • lanchunhui
  • lanchunhui
  • 2017年03月13日 18:19
  • 8835

tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)

之前已经写了一篇《Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)》,里面主要讲恢复模型然后使用该模型 假如要保存或者恢复指定tensor,并且把保存的graph...
  • ying86615791
  • ying86615791
  • 2017年07月27日 18:18
  • 3259

TensorFlow保存和加载训练模型

对于机器学习,尤其是深度学习DL的算法,模型训练可能很耗时,几个小时或者几天,所以如果是测试模块出了问题,每次都要重新运行就显得很浪费时间,所以如果训练部分没有问题,那么可以直接将训练的模型保存起来,...
  • JasonZhangOO
  • JasonZhangOO
  • 2017年03月07日 11:13
  • 7371
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:5.2 TensorFlow:模型的加载,存储,实例
举报原因:
原因补充:

(最多只允许输入30个字)