tensorflow使用笔记(四)
tensorflow使用笔记(一)Session的两种使用方式和tensorflow中的变量
tensorflow使用笔记(二)简单神经网络模型的搭建
tensorflow使用笔记(三)tensorboard可视化
tensorflow使用笔记(四)模型的保存和重载
模型的保存
在tensorflow中,模型保存后会有四个文件:(保存模型的文件路径:G:/project/python/MNIST/demo/test_model
)
文件结构:
checkpoint
该文件记录了保存的最新的checkpoint文件和其他的checkpoint文件列表
meta
该文件保存的是图结构,保存了变量、操作、结合等
data、 index
这两个文件保存了所有的变量,在tensorflow0.11版本之前是一个文件(ckpt),现在是两个
保存模型
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 定义两个变量
W = tf.Variable([1, 1], dtype=tf.float32, name='weights')
b = tf.Variable(1, dtype=tf.float32, name='b')
# 初始化变量
init = tf.global_variables_initializer()
# 构建Session
with tf.Session() as sess:
sess.run(init)
# 保存模型
saver = tf.train.Saver()
# 保存模型到 G:/project/python/MNIST/demo/test_model/model1
model1 = saver.save(sess, save_path='G:/project/python/MNIST/demo/test_model/model1')
保存完就产生四个文件:
接下来我们就可以重新加载模型,读取保存的数据等;
我们也可以用tensorboard来查看我们保存的图结构:(不会使用tensorboard看这里)
import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 定义两个变量
with tf.name_scope('variable'):
W = tf.Variable([1, 1], dtype=tf.float32, name='weights')
b = tf.Variable(1, dtype=tf.float32, name='b')
# 初始化变量
init = tf.global_variables_initializer()
# 构建Session
with tf.Session() as sess:
sess.run(init)
# 保存模型
saver = tf.train.Saver()
print(W.name)
# 这是我们用tensorboard来看看我们的图
write = tf.summary.FileWriter('logs/', sess.graph)
# 保存模型到 G:/project/python/MNIST/demo/test_model/model1
model1 = saver.save(sess, save_path='G:/project/python/MNIST/demo/test_model/model1')
可以看到我们刚保存的变量在图中;
模型重载
import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 构建Session
with tf.Session() as sess:
# 加载图
saver = tf.train.import_meta_graph("G:/project/python/MNIST/demo/test_model/model1.meta")
# 加载数据
saver.restore(sess, tf.train.latest_checkpoint("G:/project/python/MNIST/demo/test_model"))
# 使用图中的数据
graph = tf.get_default_graph()
# 把保存的W 赋给w
w = graph.get_tensor_by_name(name='weights:0')
print(sess.run(w))