# -*- coding: utf-8 -*-
"""
Created on Wed Dec 2 11:49:26 2020
@author: Melinda
"""
'''
MNIST split shapes as (num_images, 28*28 = 784 pixels):
train shape: (55000, 784)
test shape: (10000, 784)
validation shape: (5000, 784)
'''
'''
im = mnist.train.images[1] #[55000,784]
im = im.reshape(-1, 28)
pylab.imshow(im) # render the pixel values as an image
pylab.show()
'''
import os

import pylab  # convert one_hot to image (used by the preview snippet in the header string)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST (downloaded into MNIST_data/ on first run): each image is a
# flattened 28x28-pixel row of 784 floats; labels are one-hot vectors of 10.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

tf.reset_default_graph()

# ------------------------- model definition -------------------------
# Softmax regression: one fully connected layer, 784 inputs -> 10 classes.
x = tf.placeholder(tf.float32, [None, 784])  # batch of flattened images
y = tf.placeholder(tf.float32, [None, 10])   # matching one-hot labels
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
pred = tf.nn.softmax(logits)  # class probabilities, used for evaluation

# The loss is computed from the raw logits with the fused op instead of
# -sum(y * log(pred)): log(pred) is -inf whenever a softmax output underflows
# to 0 in float32, and y * -inf (with y == 0) turns the whole loss into NaN.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1  # print the average cost every `display_step` epochs

# Evaluation ops, built once: fraction of samples whose most probable class
# equals the labelled class.
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Created after all variables exist, so it checkpoints both W and b.
saver = tf.train.Saver()
model_path = "log/521model.ckpt"

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Number of mini-batches per epoch; invariant, so computed once.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_epochs):
        avg_cost = 0.0
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one optimizer step and fetch this batch's loss.
            _, batch_cost = sess.run([optimizer, cost],
                                     feed_dict={x: batch_xs, y: batch_ys})
            avg_cost += batch_cost / total_batch
        if (epoch + 1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1),
                  "cost=", "{:.9f}".format(avg_cost))
    print("Finished!")

    # ------------------------- training finished -------------------------
    # Inside this `with`, `sess` is the default session, so .eval() uses it.
    print("Accuracy:",
          accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # ------------------------- save the model -------------------------
    # Create the checkpoint directory first; depending on the TF version,
    # Saver.save fails with "Parent directory ... doesn't exist" otherwise.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)


def evaluate_saved_model(path=model_path):
    """Restore the checkpoint at *path* in a fresh session and print test accuracy.

    Use this instead of commenting out the training session and running a
    second one by hand. It is not called automatically: invoke it once a
    checkpoint has been written, e.g. from an interactive session.
    """
    print("Starting 2nd session...")
    with tf.Session() as sess:
        # restore() assigns every variable the saver tracks (W and b), so no
        # initializer needs to run first.
        saver.restore(sess, path)
        print("Test Accuracy:",
              accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
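# Mnist手写数字识别代码 (MNIST handwritten-digit recognition code)
# 最新推荐文章于 2024-05-15 21:55:39 发布 (blog-page footer captured with the code: "latest recommended article published 2024-05-15 21:55:39")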