tensorflow1.x手写数字识别

import tensorflow as tf
#minist引入数据方法
from tensorflow.examples.tutorials.mnist import input_data
import random
import matplotlib.pyplot as plt
tf.set_random_seed(777) #设置随机种子

# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation).

'''
手写数字识别
'''

# Each image is 28 pixels by 28 pixels
#定义数据集,label定义为亚编码格式
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)   #读取数据,设置独热编码one-hot=True,MNIST_data为路径目录,若数据已存在,则直接使用,否则,加载到此路径
nb_classes = 10  # 类别数10
#定义占位符
X = tf.placeholder("float", shape=[None, 784]) #28*28=784像素的图片
# labels是每张图片都对应一个one-hot的10个值的向量
Y = tf.placeholder(tf.float32, [None, nb_classes])
#权重和偏置
W = tf.Variable(tf.random_normal([784, nb_classes]), name='weight')
b = tf.Variable(tf.random_normal([nb_classes]), name='bias')
#预测模型
# hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)
logits = tf.matmul(X, W) + b
# hypothesis = logits
# y_ = tf.nn.softmax(logits)
#代价或损失函数
# cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(y_), axis=1))
#tf.nn.softmax_cross_entropy_with_logits 1.激活函数softmax  2.求交叉熵
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
# 梯度下降优化器
train = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)
#准确率计算
prediction = tf.argmax(logits, 1)
correct_prediction = tf.equal(prediction, tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#创建会话
sess = tf.Session()
sess.run(tf.global_variables_initializer()) #全局变量初始化
#迭代训练:1个epoch是所有的样本参加一次梯度下降
training_epochs = 15
batch_size = 100  # 批次大小
'''
批量处理
总共将数据拟合15遍55,000个样本
'''
for epoch in range(training_epochs):
    avg_cost = 0#平均损失
    #计算总批次数total_batch:55,000/100 = 550
    total_batch = int(mnist.train.num_examples / batch_size)
    for i in range(total_batch):#total_batch=550
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)#获取下一批次数据集
        c, _ = sess.run([cost, train], feed_dict={X: batch_xs, Y: batch_ys})
        avg_cost += c / total_batch
    # 显示损失值收敛情况
    print(epoch, avg_cost)
#准确率:前5000个测试集数据的准确率
print("Accuracy: ", sess.run(accuracy, feed_dict={X: mnist.test.images[:5000], Y: mnist.test.labels[:5000]}))
#在测试集中随机抽一个样本进行测试,并显示图片
r = random.randint(0, mnist.test.num_examples - 1)  #mnist.test.num_examples测试集样本数量
print("真实值Label: ", sess.run(tf.argmax(mnist.test.labels[r:r + 1], 1))) #axis=1, label=标签y;r:r+1取的是第r个样本
print("预测值Prediction: ", sess.run(tf.argmax(logits, 1), feed_dict={X: mnist.test.images[r:r + 1]}))  # 预测值
plt.imshow(mnist.test.images[r: r + 1].reshape(28, 28), cmap='Greys')  # 显示图片
plt.show()

while True:
    str = input()
    try:
        if str == 'q':
            break
        r = random.randint(0, mnist.test.num_examples - 1)
        print("Label: ", sess.run(tf.argmax(mnist.test.labels[r:r + 1], 1)))
        print("Prediction: ", sess.run(tf.argmax(logits, 1), feed_dict={X: mnist.test.images[r:r + 1]}))
        plt.imshow(mnist.test.images[r:r + 1].reshape(28, 28), cmap='Greys')
        plt.show()
    except:
        continue
'''
0 2.8263026701320304
1 1.06166895113208
2 0.8380613085898487
3 0.7332327307354319
4 0.6692798727750779
5 0.624611818573691
6 0.5911603328856556
7 0.5638689751245757
8 0.5417451611703092
9 0.522673569335179
10 0.506782321062955
11 0.49244763639840267
12 0.47995582400397785
13 0.4688936629078603
14 0.458703470866789
'''
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值