MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片
下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。这样的切分很重要,在机器学习模型设计时必须有一个单独的测试数据集不用于训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上(泛化)。
正如前面提到的一样,每一个MNIST数据单元有两部分组成:一张包含手写数字的图片和一个对应的标签。我们把这些图片设为“xs”,把这些标签设为“ys”。训练数据集和测试数据集都包含xs和ys,比如训练数据集的图片是 mnist.train.images ,训练数据集的标签是 mnist.train.labels
详情文档:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
简易代码实现手写数字识别展示
import tensorflow as tf
import random
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# 读取数据
mnist = input_data.read_data_sets('MNIST.data',one_hot=True)
# 占位
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
# 变量
w = tf.Variable(tf.random_normal([784,10]))
b = tf.Variable(tf.random_normal([10]))
# 前向传播
z = tf.matmul(x,w) + b
a = tf.nn.softmax(z)
# 代价
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(a),axis=1))
# 反向传播
dz = a - y
dw = tf.matmul(tf.transpose(x),dz) / tf.cast(tf.shape(x)[0],tf.float32)
db = tf.reduce_mean(dz,0)
# 更新参数
alpha = 0.01
update = [
tf.assign(w,w-alpha*dw),
tf.assign(b,b-alpha*db),
]
# 开启会话
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#周期
training = 15
# 批次
size = 100
for i in range(training):
# 初始化代价
avg_cvl = 0
# 每批次多少个
batch = int(mnist.train.num_examples/size)
# 迭代代价
for k in range(batch):
batch_x,batch_y = mnist.train.next_batch(size)
c,_ = sess.run([cost,update],feed_dict={x:batch_x,y:batch_y})
avg_cvl += c/batch
print(avg_cvl)
# 准确率
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(a,1),tf.argmax(y,1)),tf.float32))
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
# 随机数
r = random.randint(0,mnist.test.num_examples-1)
# 标签随机图片下标
print(sess.run(tf.argmax(mnist.test.labels[r:r+1],1)))
# 预测图片下标
print(sess.run(tf.argmax(a,1),feed_dict={x:mnist.test.images[r:r+1]}))
# 显示图片
plt.imshow(mnist.test.images[r:r+1].reshape(28,28),cmap='cool')
plt.show()
效果展示
0 8.499918651580805
1 4.4614689779281616
2 3.049012235294685
3 2.3698268140446057
4 1.9851832712780337
5 1.7402232361923566
6 1.569694124135103
7 1.4432838637178587
8 1.3454034652493227
9 1.2667941193147143
10 1.2023726383122533
11 1.148049990697339
12 1.102101464108987
13 1.0621832712130113
14 1.027214657826857
accuracy 0.7913
Label: [9]
Prediction: [9]