The MNIST database (Mixed National Institute of Standards and Technology database) is a database of handwritten digits containing 60,000 training samples and 10,000 test samples. It is a subset of the larger NIST database and is widely used as a training and test set for handwritten-digit recognition. It can be downloaded from the official site: http://yann.lecun.com/exdb/mnist/. Note that the files are not in a standard image format; the image data is stored in binary files, and each sample image is 28*28 pixels.
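The binary files use the IDX format: a big-endian header of 32-bit integers (a magic number followed by the dimension sizes), then raw pixel bytes. As a minimal sketch of how the images-file header can be parsed, with a synthetic in-memory buffer standing in for the real train-images-idx3-ubyte file:

```python
import struct

def parse_idx3_header(data):
    # The images file starts with four big-endian 32-bit integers:
    # magic number (2051 for images), image count, row count, column count.
    magic, num_images, rows, cols = struct.unpack(">IIII", data[:16])
    assert magic == 2051, "not an IDX3 images file"
    return num_images, rows, cols

# Synthetic header mimicking the training-images file: 60000 images of 28x28.
fake_header = struct.pack(">IIII", 2051, 60000, 28, 28)
print(parse_idx3_header(fake_header))  # (60000, 28, 28)
```

After the 16-byte header, the real file simply contains num_images * rows * cols unsigned bytes, one per pixel.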
Before use, load the dataset with mnist = input_data.read_data_sets('MNIST_data', one_hot=True). One-hot encoding originally comes from digital circuit design; here it means each label is a row vector of 10 elements in which exactly one element is 1 and all the others are 0.
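A minimal sketch of what the one-hot encoding does to a label: the digit 3 becomes a 10-element vector with a 1 in position 3 and 0 everywhere else.

```python
import numpy as np

def one_hot(label, num_classes=10):
    # Build a zero vector and set a single 1 at the label's index.
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[label] = 1.0
    return vec

print(one_hot(3))  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```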
After loading, split the training set into batches: batch_size = 100 and n_batch = mnist.train.num_examples // batch_size.
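Since MNIST has 60,000 training examples, integer division gives the number of full batches per epoch:

```python
# 60,000 training examples split into batches of 100.
num_examples = 60000
batch_size = 100
n_batch = num_examples // batch_size  # floor division: count of full batches
print(n_batch)  # 600
```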
Then a simple network is enough to train and evaluate on this dataset:
Full code:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# load the dataset
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
# set the batch size
batch_size = 100
n_batch = mnist.train.num_examples // batch_size
# define two placeholders
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
# build a simple neural network
W = tf.Variable(tf.zeros([784,10]))
B = tf.Variable(tf.zeros([10]))
Result = tf.matmul(x,W) + B
prediction = tf.nn.softmax(Result)
# loss function: quadratic cost (cross-entropy is the more common choice with softmax)
loss = tf.reduce_mean(tf.square(y - prediction))
# set up the optimizer
optimizer = tf.train.GradientDescentOptimizer(0.1)
# minimize the loss
train = optimizer.minimize(loss)
# initialize the variables
init = tf.global_variables_initializer()
# check whether predictions match the labels
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# compute the accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
# run the session
with tf.Session() as sess:
    sess.run(init)
    for step in range(51):
        for batch in range(n_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(train, feed_dict={x: batch_x, y: batch_y})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter " + str(step) + ", Testing Accuracy " + str(acc))
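The accuracy computation above (argmax comparison, then the mean of a 0/1 cast) can be illustrated in plain NumPy, independent of TensorFlow. The small arrays here are made-up stand-ins for real predictions and labels:

```python
import numpy as np

# Made-up stand-ins: 4 one-hot labels and 4 softmax-like prediction rows.
labels = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=np.float32)
preds  = np.array([[0.1, 0.8, 0.1],   # argmax 1 -> correct
                   [0.3, 0.4, 0.3],   # argmax 1 -> wrong
                   [0.2, 0.2, 0.6],   # argmax 2 -> correct
                   [0.1, 0.7, 0.2]],  # argmax 1 -> correct
                  dtype=np.float32)

# Equivalent of tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
correct = np.argmax(labels, axis=1) == np.argmax(preds, axis=1)
# Equivalent of tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
accuracy = correct.astype(np.float32).mean()
print(accuracy)  # 0.75
```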
The output is shown below (note that with further training the accuracy would keep climbing slowly):
Iter 0, Testing Accuracy 0.7426
Iter 1, Testing Accuracy 0.8336
Iter 2, Testing Accuracy 0.8603
Iter 3, Testing Accuracy 0.871
Iter 4, Testing Accuracy 0.8765
Iter 5, Testing Accuracy 0.8821
Iter 6, Testing Accuracy 0.8848
Iter 7, Testing Accuracy 0.8894
Iter 8, Testing Accuracy 0.8921
Iter 9, Testing Accuracy 0.8949
Iter 10, Testing Accuracy 0.8958
Iter 11, Testing Accuracy 0.8976
Iter 12, Testing Accuracy 0.8984
Iter 13, Testing Accuracy 0.8997
Iter 14, Testing Accuracy 0.9015
Iter 15, Testing Accuracy 0.9013
Iter 16, Testing Accuracy 0.9026
Iter 17, Testing Accuracy 0.9029
Iter 18, Testing Accuracy 0.9043
Iter 19, Testing Accuracy 0.9048
Iter 20, Testing Accuracy 0.9059
Iter 21, Testing Accuracy 0.9063
Iter 22, Testing Accuracy 0.9067
Iter 23, Testing Accuracy 0.9073
Iter 24, Testing Accuracy 0.9074
Iter 25, Testing Accuracy 0.9078
Iter 26, Testing Accuracy 0.9089
Iter 27, Testing Accuracy 0.9091
Iter 28, Testing Accuracy 0.9092
Iter 29, Testing Accuracy 0.9096
Iter 30, Testing Accuracy 0.9106
Iter 31, Testing Accuracy 0.911
Iter 32, Testing Accuracy 0.9111
Iter 33, Testing Accuracy 0.9115
Iter 34, Testing Accuracy 0.9122
Iter 35, Testing Accuracy 0.9125
Iter 36, Testing Accuracy 0.9124
Iter 37, Testing Accuracy 0.9134
Iter 38, Testing Accuracy 0.9127
Iter 39, Testing Accuracy 0.9135
Iter 40, Testing Accuracy 0.9133
Iter 41, Testing Accuracy 0.9139
Iter 42, Testing Accuracy 0.9141
Iter 43, Testing Accuracy 0.9143
Iter 44, Testing Accuracy 0.9147
Iter 45, Testing Accuracy 0.9152
Iter 46, Testing Accuracy 0.9154
Iter 47, Testing Accuracy 0.9157
Iter 48, Testing Accuracy 0.9157
Iter 49, Testing Accuracy 0.9158
Iter 50, Testing Accuracy 0.9159