1. 下载MNIST数据集
当我们开始学习编程的时候,第一件事往往是学习打印”Hello World”。就好比编程入门有Hello World,机器学习入门有MNIST。
MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:
每张图片大小为28*28,展开成一维行向量就是784维,即每张图片就是784维空间中的一个点。
tensorflow提供一个input_data.py文件,专门用于下载mnist数据,我们直接调用就可以了,代码如下:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
执行完成后,会在当前目录下新建一个文件夹MNIST_data, 下载的数据将放入这个文件夹内。下载的四个文件为:
每个子集都由两部分组成:图片部分(images)和标签部分(labels), 我们可以用下面的代码来查看 :
print(mnist.train.images