训练好后保存为pb格式的文件:
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 7 18:54:45 2019
@author: judy.yuan
"""
"""
import tensorflow as tf
def helloFunc():
print("hellFunc")
if __name__ == '__main__':
tf.reset_default_graph()
hello = tf.Variable(tf.constant('Hello World', name = "hello"))
#init = tf.initialize_all_variables() #deprecated
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
tf.train.write_graph(sess.graph_def, "modle/","hello_modle.pb", as_text=False)
saver = tf.train.Saver()
saver.save(sess, "./modle/hello_model")
"""
#coding=utf-8
# 单层SoftMax Regression分类器(无隐层):训练和保存模型模块
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.framework import graph_util
print('tensortflow:{0}'.format(tf.__version__))

# Location of the MNIST download cache and of the exported frozen graph.
# NOTE(review): hard-coded Windows paths; adjust for your environment.
DATA_DIR = "D:\\Mnist1\\"
PB_PATH = "D:\\Mnist1\\mnist.pb"
LEARNING_RATE = 0.01
BATCH_SIZE = 100
TRAIN_STEPS = 100

# one_hot=True: labels are 10-element one-hot vectors.
mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)

# ---- Model: softmax regression, y = softmax(x . W + b) ----
with tf.name_scope('input'):
    # Exported input tensor; the inference script looks it up as "input/x_input:0".
    x = tf.placeholder(tf.float32, [None, 784], name='x_input')
    y_ = tf.placeholder(tf.float32, [None, 10], name='y_input')

with tf.name_scope('layer'):
    with tf.name_scope('W'):
        W = tf.Variable(tf.zeros([784, 10]), name='Weights')
    with tf.name_scope('b'):
        b = tf.Variable(tf.zeros([10]), name='biases')
    with tf.name_scope('W_p_b'):
        Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')
    y = tf.nn.softmax(Wx_plus_b, name='final_result')

# ---- Loss and optimizer ----
with tf.name_scope('loss'):
    # Cross-entropy summed over the batch (same objective and effective learning
    # rate as -reduce_sum(y_ * log(y))), but computed from the logits so it stays
    # finite when a predicted probability underflows to 0 (log(0) -> -inf/NaN).
    loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=Wx_plus_b))
with tf.name_scope('train_step'):
    train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

# ---- Evaluation ops ----
# Predicted class index.  Created outside any name scope so the op is named
# exactly "output" -- the name passed to output_node_names below and looked up
# as "output:0" at inference time.
pre_num = tf.argmax(y, 1, output_type=tf.int32, name="output")
correct_prediction = tf.equal(pre_num, tf.argmax(y_, 1, output_type=tf.int32))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# The `with` block closes the session even if training/export raises, and makes
# `sess` the default session so .run()/.eval() below need no explicit session.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Training: mini-batch gradient descent.
    for _ in range(TRAIN_STEPS):
        batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
        train_step.run({x: batch_xs, y_: batch_ys})

    # Test-set accuracy.
    a = accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})
    print('测试正确率:{0}'.format(a))

    # Save the trained model: replace variables with their current values as
    # constants, keeping only the subgraph needed to compute the "output" node.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=['output'])
    # 'wb': write the serialized GraphDef in binary mode.
    with tf.gfile.GFile(PB_PATH, mode='wb') as f:
        f.write(output_graph_def.SerializeToString())
使用训练好的模型来判断:
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 7 18:53:00 2019
@author: judy.yuan
"""
"""
import tensorflow as tf
if __name__ == '__main__':
restore = tf.train.import_meta_graph("hello_model.meta")
sess = tf.Session()
restore.restore(sess, "hello_model")
print(sess.run(tf.get_default_graph().get_tensor_by_name("hello:0")))
"""
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# Path of the frozen graph written by the training script.
model_path = 'D:\\Mnist1\\mnist.pb'
# Image to classify.
test_image_path = "D:\\Mnist1\\test.jpg"
# MNIST encodes ink as high pixel values on a 0-valued background.  Set this to
# True if the test image is a dark digit on a light background (e.g. a photo of
# handwriting on paper) so its polarity matches the training data.
INVERT_COLORS = False

# Preprocess the image to match the training inputs: 8-bit grayscale, 28x28,
# flattened to a single 784-element row, and scaled from [0, 255] to [0, 1]
# (read_data_sets scales the training images the same way; feeding raw 0-255
# values skews the logits because the bias is not scaled along with x).
with Image.open(test_image_path) as img:
    # convert()/resize() return new, fully loaded images, so testImage remains
    # usable after the source file is closed.
    testImage = img.convert('L').resize((28, 28))
test_input = np.asarray(testImage, dtype=np.float32).reshape(1, 28 * 28) / 255.0
if INVERT_COLORS:
    test_input = 1.0 - test_input

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    # name="" keeps the original node names (no "import/" prefix).
    tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        # The frozen graph contains only constants, so there is nothing to
        # initialize before running it.
        input_x = sess.graph.get_tensor_by_name("input/x_input:0")
        output = sess.graph.get_tensor_by_name("output:0")
        # Predict with the trained model.
        pre_num = sess.run(output, feed_dict={input_x: test_input})
        print('模型预测结果为:', pre_num)

# Show the (resized, grayscale) test image together with the prediction.
plt.figure()
plt.imshow(testImage, cmap='binary')
plt.title("prediction result:" + str(pre_num))
plt.show()