手写数字识别示例三(tensorflow1.0+图示)

import os
import time
import numpy as np
import tensorflow.compat.v1 as tf
import matplotlib.pyplot as plt
import tensorflow.examples.tutorials.mnist.input_data as input_data
tf.disable_v2_behavior()
plt.rcParams["font.sans-serif"] = ["simhei"]
### 程序执行开始时间
begintime=time.time()
### 读取数据
mnist=input_data.read_data_sets("data/",one_hot=True)
### 定义参数
batch_size=200   # 批数据大小
train_epochs=100   # 迭代轮次
learning_rate=0.001   # 学习率
display_step=10   # 显示粒度
total_batch=int(mnist.train.num_examples/batch_size)   # 一轮训练有多少批次
### 定义变量
W=tf.Variable(tf.random.normal([784,10],mean=0.0,stddev=1.0,dtype=tf.float32),name="W")
B=tf.Variable(tf.zeros([10]),dtype=tf.float32,name="B")
x=tf.placeholder(tf.float32,[None,784],name="X")
y=tf.placeholder(tf.float32,[None,10],name="Y")
### 定义模型
forward=tf.matmul(x,W)+B
pred=tf.nn.softmax(forward)
loss=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
### tf.argmax用法
"""
>>> array=np.array([[1,2,3],[3,2,1],[7,8,9],[9,8,7]])
>>> array
array([[1, 2, 3],
       [3, 2, 1],
       [7, 8, 9],
       [9, 8, 7]])
>>> tf.argmax(array,0)  取出行位置
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([3, 2, 2])>
>>> tf.argmax(array,1)  取出列位置
<tf.Tensor: shape=(4,), dtype=int64, numpy=array([2, 0, 2, 0])>
>>> tf.argmax(array,-1)  取出列位置(最后一维)
<tf.Tensor: shape=(4,), dtype=int64, numpy=array([2, 0, 2, 0])>
"""
### 初始化
sess=tf.Session()
init=tf.global_variables_initializer()
sess.run(init)
### 训练
for epoch in range(train_epochs):
    for step in range(total_batch):
        xs,ys=mnist.train.next_batch(batch_size=batch_size)
        sess.run(optimizer,feed_dict={x:xs,y:ys})
    loss_epoch,acc=sess.run([loss,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
    if (epoch+1) % display_step == 0:
        print("训练轮次:{:0>5d},损失率:{:.10f},准确率:{:.5f}".format(epoch+1,loss_epoch,acc))
train_acc=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
valid_acc=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
test_acc =sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("训练集准确率:{:.5f}".format(train_acc))
print("验证集准确率:{:.5f}".format(valid_acc))
print("测试集准确率:{:.5f}".format(test_acc ))
print("程序运行耗时:{:.10f}秒".format(time.time()-begintime))
def plot_images_labels_prediction(images,labels,prediction,index,num=10):
    fig=plt.gcf()
    fig.set_size_inches(10,12)
    if num>25:
        num=25
    for i in range(num):
        ax=plt.subplot(5,5,i+1)
        ax.imshow(np.reshape(images[index],(28,28)),cmap="binary")
        title=u"标签="+str(np.argmax(labels[index]))
        if len(prediction)>0:
            title+=",预测值="+str(prediction[index])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        index+=1
    plt.show()
while True:
    os.system("clear")
    print("1.........训练数据")
    print("2.........识别数字")
    menu_input=input("请输入选择(Quit/Q退出):")
    if menu_input == "1":
        ### 训练
        for epoch in range(train_epochs):
            for step in range(total_batch):
                xs,ys=mnist.train.next_batch(batch_size=batch_size)
                sess.run(optimizer,feed_dict={x:xs,y:ys})
            loss_epoch,acc=sess.run([loss,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
            if (epoch+1) % display_step == 0:
                print("训练轮次:{:0>5d},损失率:{:.10f},准确率:{:.5f}".format(epoch+1,loss_epoch,acc))
        train_acc=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
        valid_acc=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
        test_acc =sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("训练集准确率:{:.5f}".format(train_acc))
        print("验证集准确率:{:.5f}".format(valid_acc))
        print("测试集准确率:{:.5f}".format(test_acc ))
        print("程序运行耗时:{:.10f}秒".format(time.time()-begintime))
    elif menu_input == "2":
        myindex=int(input("请输入索引数字(5-9900之间):"))
        mysum=int(input("请输入显示数量(5-25之间):"))
        if myindex<=5 or myindex>=9900:
            myindex=7425
        if mysum<=5 or mysum>=25:
            mysum=25
        prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
        plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,myindex,mysum)
    elif menu_input.upper() == "QUIT" or menu_input.upper() == "Q":
        break
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值