tensorflow模型的保存与恢复
tensorflow的模型保存与恢复功能设计的比较友好,我们只需要3-5行代码就可以实现此功能。保存模型的函数比较简单,从实
例代码中大家就能理解。
重点说一下恢复模型中的一个函数 ckpt=tf.train.get_checkpoint_state(log)
log代表模型的地址,tf.train.get_checkpoint_state()这个函数是通过checkpoint文件找到模型文件名。
它分别有两个返回值:**ckpt.model_checkpoint_path**和**ckpt.all_model_checkpoint_path**。
使用**ckpt.model_checkpoint_path**返回值时代表只恢复最后一个储存的模型。
使用**ckpt.all_model_checkpoint_path**返回值时代表恢复所有模型,具体恢复哪个需要根据自己需求编程实现。
具体用法会在实例代码中体现。
下面我用LeNet-5模型和mnist手写数字数据集举例子,上代码~~~~~~~~~
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
import os
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
log = 'D:/xx/xx/' #定义需要保存的地址
sess = tf.InteractiveSession()
#定义神经网络
def lenet5():
#~~~~~~~~
#创建一个Saver对象,后续要用这个对象来保存模型
saver = tf.train.Saver()
#开始训练
for i in range(1000):
train_step, x, y_true, h_fc3,correct_prediction,loss,learning_rate=lenet5()
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
batch = mnist.train.next_batch(32)
lossES,_=sess.run([loss,train_step], feed_dict={x: batch[0], y_true: batch[1]})
if i%100 == 0:
train_accuracy = accuracy.eval(session=sess, feed_dict={x: batch[0], y_true: batch[1]})
print('step {}, training accuracy: {},loss:{}'.format(i, train_accuracy,lossES))
check = os.path.join(log, "model.ckpt") #打开log文件夹,定义模型以model.ckpt名字进行保存
saver.save(sess, check, global_step=i) #使用Saver的对象saver来进行保存操作。注意这里,设置global=i的意思是每100次保存的模型单独储存,如果不设置的话,每次保存的模型会顶替上一次保存的。
#储存完毕!
####################下面说说怎么恢复模型#########################
import cv2
import numpy as np
import tensorflow as tf
from lenet5 import * #保存的lenet5模型文件
img = cv2.imread('2.png')
img=gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
ret3,img2 = cv2.threshold(img,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
img2=img2.reshape(1,784)
#第一步要恢复神经网络外壳
train_step, x, y_true, h_fc3,correct_prediction,loss,learning_rate=lenet5()
#第二部再次定义一个Saver对象,用于恢复模型
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(log)
saver.restore(sess, ckpt.model_checkpoint_path)#使用了ckpt.model_checkpoint_path返回值,只恢复最后模型的权重
#这就恢复好了,说明在神经网路外壳里已经填好了权重值。这时候lenet5()函数的返回值就都可以算出来了。
#下面算一下h_fc3这个返回值
h_fc3 = sess.run(h_fc3, feed_dict={x: img2})
#恢复完毕。读者可以根据自己的需求计算不同的返回值,查看分类效果。
看一下保存的模型。
tensoflow保存模型分为4部分:
checkpoint文件:记录着已经最新的保存的模型文件。
model.ckpt.data-00000-of-00001文件:保存着模型的所有变量的值。
model.ckpt.index文件:为一个string-string table,table的key值的tensor名,value为BundleEntryProto, BundleEntryProto.
model.ckpt.meta文件:保存着完整的 TensorFlow 图的协议缓存区,即所有的变量,操作,集合等等。