tensorflow学习系列六:mnist从训练保存模型再到加载模型测试

    通过前面几个系列的学习对tensorflow有了一个渐渐亲切的感觉,本文主要是从tensorflow模型训练与验证的模型进行实践一遍,以至于我们能够通过tensorflow的训练有一个整体的概念。下面主要是从训练到保存模型,然后加载模型进行预测。

# -*- coding: utf-8 -*-
"""
Created on Mon Jun 11 22:17:52 2018

func:搭建网络图

@author: kuangyongjian
"""
import tensorflow as tf


#构建图
class Network(object):
    
    def __init__(self):
        
        self.learning_rate = 0.001
        #几率已经训练的次数
        self.global_step = tf.Variable(0,trainable = False)
        
        self.x = tf.placeholder(tf.float32,[None,784])
        self.label = tf.placeholder(tf.float32,[None,10])
        
        self.w = tf.Variable(tf.zeros([784,10]))
        self.b = tf.Variable(tf.zeros([10]))
        self.y = tf.nn.softmax(tf.matmul(self.x,self.w) + self.b)
        
        self.loss = -tf.reduce_mean(self.label * tf.log(self.y) + 1e-10)
        
        self.train = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss,global_step = self.global_step)
        
        predict = tf.equal(tf.argmax(self.label,1),tf.argmax(self.y,1))
        self.accuracy = tf.reduce_mean(tf.cast(predict,tf.float32))
    
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 12 09:16:52 2018

func:网络训练,以及对应的模型保存

@author: kuangyongjian
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from model import Network

CKPT_DIR = 'ckpt'

class Train(object):
    
    def __init__(self):
        self.net = Network()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        self.data = input_data.read_data_sets('../data_set',one_hot = True)
        
    def train(self):
        batch_size = 64
        train_step = 10000
        step = 0
        #每隔1000步保存一次模型
        save_interval = 1000
        
        #tf.train.Saver用于保存训练的结果
        #max to keep 用于设置最多保存多少个模型
        #如果保存的模型超过这个值,最旧的模型被删除
        saver = tf.train.Saver(max_to_keep = 10)
        
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.get_checkpoint_state(CKPT_DIR):
            saver.restore(self.sess,ckpt.model_checkpoint_path)
            #读取网络中的global_step的值,即当前已经训练的次数
            step = self.sess.run(self.net.global_step)
            print('continue from')
            print('  -> Minibatch update : ',step)
            
        while step < train_step:
            x,label = self.data.train.next_batch(batch_size)
            _,loss = self.sess.run([self.net.train,self.net.loss],
                                   feed_dict = {self.net.x: x,self.net.label:label})
            
            step = self.sess.run(self.net.global_step)
            if step % 1000 == 0:
                print('第%6d步,当前loss: %.3f'%(step,loss))
                
            #模型保存在ckpt文件夹下
            #模型文件名最后会增加global_step的值,比如2000的模型文件名为model-2000
            if step % save_interval == 0:
                saver.save(self.sess,CKPT_DIR + '/model',global_step = step)
    
    def calculate_accuracy(self):
        test_x = self.data.test.images
        test_label = self.data.test.labels
        acc = self.sess.run(self.net.accuracy,feed_dict = {self.net.x:test_x,self.net.label:test_label})
        
        print("准确率: %.3f,共测试了%d张图片 " % (acc, len(test_label)))
            
                
if __name__ == '__main__':
    model = Train()
    model.train()
    model.calculate_accuracy()
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 12 09:36:55 2018

func:加载模型,进行模型测试

@author: kuangyongjian
"""
import tensorflow as tf
import numpy as np
from PIL import Image
from model import Network

CKPT_DIR = 'ckpt'


class Predict(object):
    
    def __init__(self):
        #清除默认图的堆栈,并设置全局图为默认图
        #若不进行清楚则在第二次加载的时候报错,因为相当于重新加载了两次
        tf.reset_default_graph() 
        self.net = Network()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        
        #加载模型到sess中
        self.restore()
        print('load susess')
    
    def restore(self):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
        print(ckpt.model_checkpoint_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess,ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError('未保存模型')
        
    def predict(self,image_path):
        #读取图片并灰度化
        img = Image.open(image_path).convert('L')
        flatten_img = np.reshape(img,784)
        x = np.array([1 - flatten_img])
        y = self.sess.run(self.net.y,feed_dict = {self.net.x:x})
        
        print(image_path)
        print(' Predict digit',np.argmax(y[0]))
        
        
if __name__ == '__main__':
    model = Predict()
    model.predict('0.png')
    model.predict('../test_images/1.png')
    model.predict('../test_images/4.png')

注意文中保存模型和加载模型的方式,特别是在加载模型的时候比较容易出错。

若有不当之处请指教,谢谢!

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值