mnist是最最常用的一个数据集,tesnorflow中也把mnist分类做为一个入门的例子。
但是这个数据是经过封装的,所以,今天我就要把这个封装的过程弄明白
代码中读取mnist数据:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
所以首先,取github上找到对应的源码:
下面是文件的代码:
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
晕,又调用了别的文件,一样取找到对应文件:
tensorflow/tensorflow/contrib/learn/python/learn/datasets/mnist.py
代码有点长,我就挑出几段
validation_images = rain_images[:validation_size]
validation_labels = rain_labels[:validation_size]
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:]
这里明显就是把图像分为train和test集,并且没有使用随机方式,直接就是按照前后顺序划分的。
options = dict(dtype=dtype, reshape=reshape, seed=seed)
train = DataSet(train_images, train_labels, **options)
validation = DataSet(validation_images, validation_labels, **options)
test = DataSet(test_images, test_labels, **options)
DataSet函数在就在文件中,看意思应该是将数据随机打乱,格式由二维图像转成一维数组,数据转换成float型(根据输入参数)
读取batch代码:
batch_xs, batch_ys = mnist.train.next_batch(100)
函数原型也在文件中,看意思也是明显采用的是不放回抽样。就是有个全局变量,记录已经抽样多少个了,然后拿后面的。
print(batch_xs.shape)
print(batch_ys.shape)
输出结果:
(100, 784)
(100, 10)
意思很明白了,得到2个numpy数组。