import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #图片的路径需要与运行文件的路径在同一级目录,或者也可以使用绝对路径
# 定义一个批次
batch_size = 100
n_batch = mnist.train.num_examples//batch_size # 这里//表示整除
# 定义两个占位符placeholder
x = tf.placeholder(tf.float32, [None, 784]) # 784=28*28 表示手写数据图片的像素个数
y = tf.placeholder(tf.float32, [None, 10]) # 10表示0~9,10分类
# 创建一个简单的神经网络
w = tf.Variable(tf.zeros([784, 10])) # 10个神经元
b = tf.Variable(tf.zeros([10])) # 偏置
prediction = tf.nn.softmax(tf.matmul(x, w)+b)
# 二次代价函数
loss = tf.reduce_mean(tf.square(y - prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
# 初始化变量
init = tf.global_variables_initializer()
# 结果存放在一个布尔类型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax()返回一维张量中最大值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # cast()函数:类型转换
with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
print("Inter:"+str(epoch)+", Testing accuracy:"+str(acc))
运行结果如下:
Extracting MNIST_data\train-images-idx3-ubyte.gz Extracting MNIST_data\train-labels-idx1-ubyte.gz Extracting MNIST_data\t10k-images-idx3-ubyte.gz Extracting MNIST_data\t10k-labels-idx1-ubyte.gz Inter:0, Testing accuracy:0.8309 Inter:1, Testing accuracy:0.8712 Inter:2, Testing accuracy:0.8815 Inter:3, Testing accuracy:0.8883 Inter:4, Testing accuracy:0.8949 Inter:5, Testing accuracy:0.8974 Inter:6, Testing accuracy:0.9002 Inter:7, Testing accuracy:0.9009 Inter:8, Testing accuracy:0.9032 Inter:9, Testing accuracy:0.9046 Inter:10, Testing accuracy:0.9062 Inter:11, Testing accuracy:0.9071 Inter:12, Testing accuracy:0.9075 Inter:13, Testing accuracy:0.9092 Inter:14, Testing accuracy:0.9101 Inter:15, Testing accuracy:0.9114 Inter:16, Testing accuracy:0.9113 Inter:17, Testing accuracy:0.9121 Inter:18, Testing accuracy:0.9136 Inter:19, Testing accuracy:0.9133 Inter:20, Testing accuracy:0.914