MNIST数据集是人工智能学习入门的数据集,包含了一系列的手写的数字图片
载入MNIST数据集的方法很简单,Tensorflow集成了载入数据集的方法
首先导入tensorflow模块和matplotlib.pyplot模块,pyplot是为了在显示载入的图片
import tensorflow as tf
import matplotlib.pyplot as plt
然后载入MNIST数据集
# load MNIST from keras datasets
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# train_images: 60000*28*28, train_labels: 60000*1
# test_images: 10000*28*28, test_labels: 10000*1
调用集成的tf.keras.datasets.mnist.load_data函数,会自动从网络下载数据集,并产生train_images, train_labels, test_images, test_labels变量,分别为训练和测试的图片和标记。其中训练集为60000张图片,测试集为10000张图片,每张图片的大小为28*28像素。
通过subplot来显示数据集的一些图片,如下
# lines and columns of subplots
m = 10
n = 10
num = m*n
# size of figure
plt.figure(figsize=(11,11))
# plot first 100 pictures in train images
for i in range(num):
plt.subplot(m,n,i+1)
plt.imshow(train_images[i], cmap='gray_r')
plt.xticks([])
plt.yticks([])
plt.show()
这里分10x10方格,每一个方格显示一张图片,显示训练数据集的前100张图片,结果如下