tensorflow模型的保存与恢复

tensorflow模型的保存与恢复

tensorflow的模型保存与恢复功能设计的比较友好,我们只需要3-5行代码就可以实现此功能。保存模型的函数比较简单,从实
例代码中大家就能理解。

重点说一下恢复模型中的一个函数 ckpt=tf.train.get_checkpoint_state(log)

   log代表模型的地址,tf.train.get_checkpoint_state()这个函数是通过checkpoint文件找到模型文件名。
   它分别有两个返回值:**ckpt.model_checkpoint_path**和**ckpt.all_model_checkpoint_path**。
   
   使用**ckpt.model_checkpoint_path**返回值时代表只恢复最后一个储存的模型。
   使用**ckpt.all_model_checkpoint_path**返回值时代表恢复所有模型,具体恢复哪个需要根据自己需求编程实现。
   
   具体用法会在实例代码中体现。

下面我用LeNet-5模型和mnist手写数字数据集举例子,上代码~~~~~~~~~

import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
import os

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
log = 'D:/xx/xx/'   #定义需要保存的地址
sess = tf.InteractiveSession()
#定义神经网络
def lenet5():
#~~~~~~~~
#创建一个Saver对象,后续要用这个对象来保存模型
saver = tf.train.Saver()
#开始训练 
for i in range(1000):
		train_step, x, y_true, h_fc3,correct_prediction,loss,learning_rate=lenet5()
    	accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
        batch = mnist.train.next_batch(32)
        lossES,_=sess.run([loss,train_step], feed_dict={x: batch[0], y_true: batch[1]})
        if i%100 == 0:
            train_accuracy = accuracy.eval(session=sess, feed_dict={x: batch[0], y_true: batch[1]})
            print('step {}, training accuracy: {},loss:{}'.format(i, train_accuracy,lossES))
            check = os.path.join(log, "model.ckpt") #打开log文件夹,定义模型以model.ckpt名字进行保存
            saver.save(sess, check, global_step=i) #使用Saver的对象saver来进行保存操作。注意这里,设置global=i的意思是每100次保存的模型单独储存,如果不设置的话,每次保存的模型会顶替上一次保存的。
            #储存完毕!
####################下面说说怎么恢复模型#########################
import cv2
import numpy as np
import tensorflow as tf
from lenet5 import * #保存的lenet5模型文件
img = cv2.imread('2.png')
img=gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
ret3,img2 = cv2.threshold(img,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
img2=img2.reshape(1,784)
#第一步要恢复神经网络外壳
train_step, x, y_true, h_fc3,correct_prediction,loss,learning_rate=lenet5()
#第二部再次定义一个Saver对象,用于恢复模型
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(log)
saver.restore(sess, ckpt.model_checkpoint_path)#使用了ckpt.model_checkpoint_path返回值,只恢复最后模型的权重
#这就恢复好了,说明在神经网路外壳里已经填好了权重值。这时候lenet5()函数的返回值就都可以算出来了。
#下面算一下h_fc3这个返回值
h_fc3 = sess.run(h_fc3, feed_dict={x: img2})
#恢复完毕。读者可以根据自己的需求计算不同的返回值,查看分类效果。

看一下保存的模型。
模型列表

tensoflow保存模型分为4部分:
      checkpoint文件:记录着已经最新的保存的模型文件。
      model.ckpt.data-00000-of-00001文件:保存着模型的所有变量的值。
      model.ckpt.index文件:为一个string-string table,table的key值的tensor名,value为BundleEntryProto, BundleEntryProto.
      model.ckpt.meta文件:保存着完整的 TensorFlow 图的协议缓存区,即所有的变量,操作,集合等等。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值