通过前面几个系列的学习对tensorflow有了一个渐渐亲切的感觉,本文主要是从tensorflow模型训练与验证的模型进行实践一遍,以至于我们能够通过tensorflow的训练有一个整体的概念。下面主要是从训练到保存模型,然后加载模型进行预测。
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 11 22:17:52 2018
func:搭建网络图
@author: kuangyongjian
"""
import tensorflow as tf
#构建图
class Network(object):
def __init__(self):
self.learning_rate = 0.001
#几率已经训练的次数
self.global_step = tf.Variable(0,trainable = False)
self.x = tf.placeholder(tf.float32,[None,784])
self.label = tf.placeholder(tf.float32,[None,10])
self.w = tf.Variable(tf.zeros([784,10]))
self.b = tf.Variable(tf.zeros([10]))
self.y = tf.nn.softmax(tf.matmul(self.x,self.w) + self.b)
self.loss = -tf.reduce_mean(self.label * tf.log(self.y) + 1e-10)
self.train = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss,global_step = self.global_step)
predict = tf.equal(tf.argmax(self.label,1),tf.argmax(self.y,1))
self.accuracy = tf.reduce_mean(tf.cast(predict,tf.float32))
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 12 09:16:52 2018
func:网络训练,以及对应的模型保存
@author: kuangyongjian
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from model import Network
CKPT_DIR = 'ckpt'
class Train(object):
def __init__(self):
self.net = Network()
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
self.data = input_data.read_data_sets('../data_set',one_hot = True)
def train(self):
batch_size = 64
train_step = 10000
step = 0
#每隔1000步保存一次模型
save_interval = 1000
#tf.train.Saver用于保存训练的结果
#max to keep 用于设置最多保存多少个模型
#如果保存的模型超过这个值,最旧的模型被删除
saver = tf.train.Saver(max_to_keep = 10)
ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
if ckpt and ckpt.get_checkpoint_state(CKPT_DIR):
saver.restore(self.sess,ckpt.model_checkpoint_path)
#读取网络中的global_step的值,即当前已经训练的次数
step = self.sess.run(self.net.global_step)
print('continue from')
print(' -> Minibatch update : ',step)
while step < train_step:
x,label = self.data.train.next_batch(batch_size)
_,loss = self.sess.run([self.net.train,self.net.loss],
feed_dict = {self.net.x: x,self.net.label:label})
step = self.sess.run(self.net.global_step)
if step % 1000 == 0:
print('第%6d步,当前loss: %.3f'%(step,loss))
#模型保存在ckpt文件夹下
#模型文件名最后会增加global_step的值,比如2000的模型文件名为model-2000
if step % save_interval == 0:
saver.save(self.sess,CKPT_DIR + '/model',global_step = step)
def calculate_accuracy(self):
test_x = self.data.test.images
test_label = self.data.test.labels
acc = self.sess.run(self.net.accuracy,feed_dict = {self.net.x:test_x,self.net.label:test_label})
print("准确率: %.3f,共测试了%d张图片 " % (acc, len(test_label)))
if __name__ == '__main__':
model = Train()
model.train()
model.calculate_accuracy()
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 12 09:36:55 2018
func:加载模型,进行模型测试
@author: kuangyongjian
"""
import tensorflow as tf
import numpy as np
from PIL import Image
from model import Network
CKPT_DIR = 'ckpt'
class Predict(object):
def __init__(self):
#清除默认图的堆栈,并设置全局图为默认图
#若不进行清楚则在第二次加载的时候报错,因为相当于重新加载了两次
tf.reset_default_graph()
self.net = Network()
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
#加载模型到sess中
self.restore()
print('load susess')
def restore(self):
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
print(ckpt.model_checkpoint_path)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(self.sess,ckpt.model_checkpoint_path)
else:
raise FileNotFoundError('未保存模型')
def predict(self,image_path):
#读取图片并灰度化
img = Image.open(image_path).convert('L')
flatten_img = np.reshape(img,784)
x = np.array([1 - flatten_img])
y = self.sess.run(self.net.y,feed_dict = {self.net.x:x})
print(image_path)
print(' Predict digit',np.argmax(y[0]))
if __name__ == '__main__':
model = Predict()
model.predict('0.png')
model.predict('../test_images/1.png')
model.predict('../test_images/4.png')
注意文中保存模型和加载模型的方式,特别是在加载模型的时候比较容易出错。
若有不当之处请指教,谢谢!