tensorflow训练minist模型

训练好后保存为pb格式的文件:

# -*- coding: utf-8 -*-
"""
Created on Mon Jan  7 18:54:45 2019

@author: judy.yuan
"""
"""
import tensorflow as tf
 
 
def helloFunc():
    print("hellFunc")
 
 
if __name__ == '__main__':
    tf.reset_default_graph()
    hello = tf.Variable(tf.constant('Hello World', name = "hello"))
    #init = tf.initialize_all_variables() #deprecated
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
 
    tf.train.write_graph(sess.graph_def, "modle/","hello_modle.pb", as_text=False)
    saver = tf.train.Saver()
    saver.save(sess, "./modle/hello_model")
"""


#coding=utf-8
# 单隐层SoftMax Regression分类器:训练和保存模型模块
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.framework import graph_util
print('tensortflow:{0}'.format(tf.__version__))
 
mnist = input_data.read_data_sets("D:\\Mnist1\\", one_hot=True)
 
#create model
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点名:x_input
    y_ = tf.placeholder(tf.float32,[None,10],name='y_input')
with tf.name_scope('layer'):
    with tf.name_scope('W'):
        #tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
        W = tf.Variable(tf.zeros([784,10]),name='Weights')
    with tf.name_scope('b'):
        b = tf.Variable(tf.zeros([10]),name='biases')
    with tf.name_scope('W_p_b'):
        Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')
 
    y = tf.nn.softmax(Wx_plus_b, name='final_result')
 
# 定义损失函数和优化方法
with tf.name_scope('loss'):
    loss = -tf.reduce_sum(y_ * tf.log(y))
with tf.name_scope('train_step'):
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    print(train_step)
# 初始化
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)
# 训练
for step in range(100):
    batch_xs,batch_ys =mnist.train.next_batch(100)
    train_step.run({x:batch_xs,y_:batch_ys})
    # variables = tf.all_variables()
    # print(len(variables))
    # print(sess.run(b))
 
# 测试模型准确率
pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点名:output
correct_prediction = tf.equal(pre_num,tf.argmax(y_,1,output_type='int32'))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})
print('测试正确率:{0}'.format(a))
 
# 保存训练好的模型
#形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])
with tf.gfile.FastGFile('D:\\Mnist1\\mnist.pb', mode='wb') as f:#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
    f.write(output_graph_def.SerializeToString())
sess.close()
1

使用训练好的模型来判断:

# -*- coding: utf-8 -*-
"""
Created on Mon Jan  7 18:53:00 2019

@author: judy.yuan
"""
"""
import tensorflow as tf
 
 
if __name__ == '__main__':
    restore = tf.train.import_meta_graph("hello_model.meta")
    sess = tf.Session()
    restore.restore(sess, "hello_model")
 
 
    print(sess.run(tf.get_default_graph().get_tensor_by_name("hello:0")))
"""


import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
 
#模型路径
model_path = 'D:\\Mnist1\\mnist.pb'
#测试图片
testImage = Image.open("D:\\Mnist1\\test.jpg");
 
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
 
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # x_test = x_test.reshape(1, 28 * 28)
        input_x = sess.graph.get_tensor_by_name("input/x_input:0")
        output = sess.graph.get_tensor_by_name("output:0")
 
        #对图片进行测试
        testImage=testImage.convert('L')
        testImage = testImage.resize((28, 28))
        test_input=np.array(testImage)
        test_input = test_input.reshape(1, 28 * 28)
        pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果
        print('模型预测结果为:',pre_num)
        #显示测试的图片
        # testImage = test_x.reshape(28, 28)
        fig = plt.figure(), plt.imshow(testImage,cmap='binary')  # 显示图片
        plt.title("prediction result:"+str(pre_num))
        plt.show()

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值