获取用于训练和测试的数据
1、 下载MNIST数据集到本地
不知道MNIST数据集是什么?看这里:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
根据上一篇的介绍,我们用于训练的数据应该是一个一个的mini-batch,因此,对于我们即将处理的数据,最重要的就是从训练集中分出一个一个的mini-batch,看起来好像不难,但是我们要保证每次导入mini-batch中的数据都是随机的,而且所有数据在一个epoch(训练数据全部使用一次就是完成了一个epoch)只能出现一次,mini-batch中的image还要必须和label对应,而且整个过程要保证快速,对于庞大、复杂、类型多样的训练数据,这可不是已经容易的事情。所以现有的用于深度学习入门的MNIST手写字符识别这个例子中,数据导入这一块直接被忽略了,MNIST直接提供了划分mini-batch的函数给用户调用,本文不使用MNIST数据集中封装的函数进行数据导入,因为我们学习深度学习是为了处理自己的数据集构建自己的网络,那时候可没有封装好的函数可以用!但我们也不好获取那么多规范的数据,我们也选用MNIST数据集,将数据下载到本地,然后自己实现高效的划分mini-batch。
1.1 获取input_data.py文件
复制下面的代码,到Spyder中粘贴:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
按F5运行重命名为input_data.py