TensorFlow入门-MNIST入门
本文介绍了如何查看MNIST数据集、如何利用TensorFlow进行MNIST手写数据集的识别、以及如何利用训练好的模型进行数字识别。
项目结构
以下所有代码均在main.py内。
MNIST
import input_data # 需要将文末py文件导入项目
import matplotlib.pyplot as plt # plt 用于显示图片
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 自动加载数据集,数据集在有网络的情况下会自动下载
# 显示mnist图片与标签
border = 1
cur = 0
while cur < border:
mnist_img = mnist.train.images[cur].reshape((28, 28))
mnist_tag = mnist.train.labels[cur]
plt.imshow(mnist_img, cmap='gray') # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()
print(mnist_tag)
cur += 1
显示MNIST图片:
为一个28x28的数组的一纬数组。
标签:
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
Softmax函数
可以返回10个数字类的概率值。
利用Softmax进行MNIST数据识别的过程为:
Softmax识别MNIST手写数据集主模块
# 定义输入占位符,None代表可以输入的图片数为任意值
x = tf.placeholder("float", [None, 784])
# 变量,在训练时可以进行数值修改
W = tf.Variable(tf.zeros([784, 10]))
b