Tensorflow中mnist数据的封装

mnist是最最常用的一个数据集,tesnorflow中也把mnist分类做为一个入门的例子。
但是这个数据是经过封装的,所以,今天我就要把这个封装的过程弄明白
代码中读取mnist数据:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

所以首先,取github上找到对应的源码:
这里写图片描述
下面是文件的代码:

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

晕,又调用了别的文件,一样取找到对应文件:
tensorflow/tensorflow/contrib/learn/python/learn/datasets/mnist.py
代码有点长,我就挑出几段

  validation_images = rain_images[:validation_size]
  validation_labels = rain_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

这里明显就是把图像分为train和test集,并且没有使用随机方式,直接就是按照前后顺序划分的。

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

DataSet函数在就在文件中,看意思应该是将数据随机打乱,格式由二维图像转成一维数组,数据转换成float型(根据输入参数)

读取batch代码:

batch_xs, batch_ys = mnist.train.next_batch(100) 

函数原型也在文件中,看意思也是明显采用的是不放回抽样。就是有个全局变量,记录已经抽样多少个了,然后拿后面的。

print(batch_xs.shape)
print(batch_ys.shape)

输出结果:
(100, 784)
(100, 10)
意思很明白了,得到2个numpy数组。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值