模型的存储和加载
训练好一个神经网络后,我们希望能够将其应用在预测数据上。那么,如何把模型存储起来呢?同时,对于一个已经存储起来的模型,在将其应用在预测数据上时又如何加载呢?
Tensorflow的API提供了以下两种方式来存储和加载模型
(1)生成检查点文件,扩展为一般为.ckpt,通过在tf.train.Saver对象上调用Saver.save()生成。它包含权重和其他在程序中定义的变量,不包含图结构。如果需要在另一个程序中使用,需要重新创建图形结构,并告诉Tensorflow如何处理这些权重。
(2)生成图协议文件,这是一个二进制文件,扩展名一般为.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。
1.模型的存储与加载
模型存储主要是建立一个tf.train.Saver()来保存变量,并且指定保存的位置,一般模型的扩展名为.ckpt。
存储模型
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# 1.加载数据
mnist = input_data.read_data_sets('./data/MNIST_data/', one_hot=True)
tr_X, tr_Y, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
# 2.设置占位符
X = tf.placeholder('float', [None, 784])
y = tf.placeholder('float', [None, 10])
# 3.定义权重函数
def init_weights(shape):
return tf.Variable(tf.random_normal(shape=shape, stddev=0.01))
# 4.初始化权重参数
w_h = init_weights([784, 625])
w_h2 = init_weights([625, 625])
w_o = init_weights([625, 10])
# 5.定义网络模型
def model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden):
# *1.第一个全连接层
X = tf.nn.dropout(X, p_keep_input)
h = tf.nn.relu(tf.matmul(X, w_h))
h = tf.nn.dropout(h, p_keep_hidden)
# *2.第二个全连接层
h2 = tf.nn.relu(tf.matmul(h, w_h2))
h2 = tf.nn.dropout(h2, p_keep_hidden)
return tf.matmul(h2, w_o)
# 6.生成网络模型,得到预测值
p_keep_input = tf.placeholder('float')
p_keep_hidden = tf.placeholder('float')
py_x = model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden)
# 7.定义损失函数
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.arg_max(py_x, 1)
ckpt_dir = './data/ckpt_dir'
if not os.path.exists(ckpt_dir): # 判断文件夹是否存在
os.makedirs(ckpt_dir) # 创建新的文件夹
saver = tf.train.Saver() # 创建节点保存器
# non_storable_variable = tf.Variable(777) # 可有可无
global_step = tf.Variable(0, name='global_step', trainable=False)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables()) # 初始化之前声明的变量
start = global_step.eval() # 获取global_step中的值,也就是初始值0.
print('Start from:', start)
for i in range(start, 100):
for start, end in zip(range(0, len(tr_X), 128), range(128, len(tr_X)+1, 128)):
sess.run(train_op, feed_dict={X:tr_X[start:end], y:tr_Y[start:end], p_keep_input:0.8, p_keep_hidden:0.5})
global_step.assign(i).eval()
saver.save(sess, ckpt_dir+'/model.ckpt', global_step=global_step)
模型的加载
# 模型的加载
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
2.图的存储与加载
当仅保存图模型时,才将图写入二进制协议文件中。
# 图的储存
# 当且保存图模型的时候,才将图写入二进制文件中
v = tf.Variable(0, name='chaucer')
sess = tf.Session()
tf.train.write_graph(sess.graph_def, './data/tmp/tfmodel','train_mnist.pbtxt')
图的读取,当读取时,从协议文件中都取出来。
# 图的读取
# 当读取时,从协议文件中都取出来
with tf.Session() as sess:
with open('./data/tmp/tfmodel/train_mnist.pbtxt', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_sess.graph.as_default()
tf.import_graph_def(graph_def, name='tfgraph')
参考:
Tensorflow技术解析与实战 李嘉璇