import os
import time
import numpy as np
import tensorflow.compat.v1 as tf
import matplotlib.pyplot as plt
import tensorflow.examples.tutorials.mnist.input_data as input_data
tf.disable_v2_behavior()  # use the TF 1.x graph/session API
# Use the SimHei font so the Chinese plot titles can be rendered.
plt.rcParams["font.sans-serif"] = ["simhei"]
### Program start time
begintime=time.time()
### Load the data set from the "data/" directory, with one-hot labels
mnist=input_data.read_data_sets("data/",one_hot=True)
### Hyperparameters
batch_size=200       # mini-batch size
train_epochs=100     # number of training epochs
learning_rate=0.001  # learning rate
display_step=10      # print progress every display_step epochs
# Number of full batches per epoch. Floor division keeps this an exact
# integer instead of truncating a float quotient with int(a/b).
total_batch=mnist.train.num_examples//batch_size
### Define variables
# Weights start from N(0, 1) samples, biases from zeros.
W=tf.Variable(tf.random.normal([784,10],mean=0.0,stddev=1.0,dtype=tf.float32),name="W")
B=tf.Variable(tf.zeros([10]),dtype=tf.float32,name="B")
# Inputs: rows of 784 values (flattened images); targets: one-hot rows over 10 classes.
x=tf.placeholder(tf.float32,[None,784],name="X")
y=tf.placeholder(tf.float32,[None,10],name="Y")
### Define the model (single-layer softmax regression)
forward=tf.matmul(x,W)+B     # logits, shape (batch, 10)
pred=tf.nn.softmax(forward)  # class probabilities
# Cross-entropy is computed from the logits by TensorFlow's fused op.
# A hand-written -sum(y*log(softmax(logits))) evaluates log(0) = -inf whenever
# a softmax probability underflows to 0 (large logits), producing a NaN loss
# and NaN gradients; the fused op is numerically stable.
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=forward))
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
# tf.argmax(t, axis) returns the index of the largest entry along `axis`:
#   array = [[1, 2, 3],
#            [3, 2, 1],
#            [7, 8, 9],
#            [9, 8, 7]]
#   tf.argmax(array, 0)  -> [3, 2, 2]     (row index of each column's maximum)
#   tf.argmax(array, 1)  -> [2, 0, 2, 0]  (column index of each row's maximum)
#   tf.argmax(array, -1) -> [2, 0, 2, 0]  (-1 is the last axis, here the same as 1)
# Accuracy is the fraction of samples whose most probable class equals the label.
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
### Initialise: build the initialiser op for all global variables,
### open a session, then run the initialiser inside it.
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
### Train: train_epochs passes of mini-batch gradient descent over the training set
for epoch in range(train_epochs):
    for _ in range(total_batch):
        xs,ys=mnist.train.next_batch(batch_size=batch_size)
        sess.run(optimizer,feed_dict={x:xs,y:ys})
    # Evaluate on the validation set only when the result is printed
    # (every display_step epochs, and always after the final epoch) instead of
    # running a full validation pass every epoch and discarding most results.
    if (epoch+1) % display_step == 0 or epoch+1 == train_epochs:
        loss_epoch,acc=sess.run([loss,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
        print("训练轮次:{:0>5d},损失率:{:.10f},准确率:{:.5f}".format(epoch+1,loss_epoch,acc))
# Final accuracy on the training, validation and test sets.
train_acc=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
valid_acc=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
test_acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("训练集准确率:{:.5f}".format(train_acc))
print("验证集准确率:{:.5f}".format(valid_acc))
print("测试集准确率:{:.5f}".format(test_acc))
# Elapsed time since program start (includes data loading).
print("程序运行耗时:{:.10f}秒".format(time.time()-begintime))
def plot_images_labels_prediction(images,labels,prediction,index,num=10):
    """Show consecutive images in a 5x5 grid, titled with label and prediction.

    Args:
        images: indexable collection of flattened images; each is reshaped
            to 28x28 for display.
        labels: label vectors aligned with ``images``; the shown label is
            the argmax of each vector.
        prediction: predicted class per image, indexed like ``images``.
            Pass an empty sequence or None to omit predictions from titles.
        index: position of the first image to show.
        num: number of images to show; clamped to 25 (the grid size) and to
            the number of images remaining after ``index``.
    """
    # Clamp so we never index past the end of `images` (previously an
    # IndexError) and never ask for more than the 25 grid cells.
    num = max(0, min(num, 25, len(images) - index))
    # Draw on a fresh figure so axes from an earlier call cannot linger.
    plt.figure(figsize=(10, 12))
    show_prediction = prediction is not None and len(prediction) > 0
    for i in range(num):
        idx = index + i
        ax = plt.subplot(5, 5, i + 1)
        ax.imshow(np.reshape(images[idx], (28, 28)), cmap="binary")
        title = "标签=" + str(np.argmax(labels[idx]))
        if show_prediction:
            title += ",预测值=" + str(prediction[idx])
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
def _clear_screen():
    """Clear the terminal: "cls" on Windows, "clear" elsewhere."""
    os.system("cls" if os.name == "nt" else "clear")


def _read_int(prompt, low, high, default):
    """Prompt for an integer and return it if it lies in [low, high].

    Returns ``default`` when the reply is not an integer (instead of crashing
    with ValueError) or falls outside the inclusive range shown in the prompt.
    """
    try:
        value = int(input(prompt))
    except ValueError:
        return default
    return value if low <= value <= high else default


# Build the prediction op once. Calling tf.argmax(...) inside the loop adds a
# new node to the graph every time option 2 is chosen.
predict_op = tf.argmax(pred, 1)

### Interactive menu: 1 = train further, 2 = show test images with predictions, Q/Quit = exit
while True:
    _clear_screen()
    print("1.........训练数据")
    print("2.........识别数字")
    menu_input = input("请输入选择(Quit/Q退出):").strip()
    if menu_input == "1":
        # Continue training the current session's model for train_epochs more epochs.
        train_begin = time.time()
        for epoch in range(train_epochs):
            for _ in range(total_batch):
                xs, ys = mnist.train.next_batch(batch_size=batch_size)
                sess.run(optimizer, feed_dict={x: xs, y: ys})
            # Evaluate on the validation set only when the result is printed.
            if (epoch + 1) % display_step == 0 or epoch + 1 == train_epochs:
                loss_epoch, acc = sess.run(
                    [loss, accuracy],
                    feed_dict={x: mnist.validation.images, y: mnist.validation.labels},
                )
                print("训练轮次:{:0>5d},损失率:{:.10f},准确率:{:.5f}".format(epoch + 1, loss_epoch, acc))
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels})
        valid_acc = sess.run(accuracy, feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("训练集准确率:{:.5f}".format(train_acc))
        print("验证集准确率:{:.5f}".format(valid_acc))
        print("测试集准确率:{:.5f}".format(test_acc))
        # Time this training run only; measuring from program start (begintime)
        # would also count data loading, the first run and time spent in the menu.
        print("本次训练耗时:{:.10f}秒".format(time.time() - train_begin))
        # Pause so the results are not wiped by the screen clear at the top of the loop.
        input("按回车键返回菜单...")
    elif menu_input == "2":
        # Non-numeric or out-of-range replies fall back to index 7425 / 25 images.
        myindex = _read_int("请输入索引数字(5-9900之间):", 5, 9900, 7425)
        mysum = _read_int("请输入显示数量(5-25之间):", 5, 25, 25)
        prediction_result = sess.run(predict_op, feed_dict={x: mnist.test.images})
        plot_images_labels_prediction(mnist.test.images, mnist.test.labels, prediction_result, myindex, mysum)
    elif menu_input.upper() in ("QUIT", "Q"):
        break

# Release the session's resources once the user quits.
sess.close()
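# 手写数字识别示例三(tensorflow1.0+图示) -- "Handwritten digit recognition, example 3 (TensorFlow 1.0, with plots)"
# 最新推荐文章于 2021-12-24 17:04:35 发布 -- web-page residue: "latest recommended article published 2021-12-24 17:04:35"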