tensorflow 8. MNIST的脚手架--input_data

65 篇文章 5 订阅
34 篇文章 4 订阅

在之前的案例中经常见到这样使用MNIST数据集的用法:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)

这句话的作用是,如果”/tmp/data/”目录下存在mnist数据集,则加载,否则先下载后加载。

input_data的源码在tensorflow的相应目录下,其内容为:

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# 最后一句尤为重要
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import

最后一句才是需要的内容。这一句来自tensorflow.contrib.learn.python.learn.datasets.mnist。

read_data_sets代码如下:

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  if not source_url:  # empty string check
    source_url = DEFAULT_SOURCE_URL

  # 数据集主要分为训练和测试两部分
  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  # 指定目录下没有则会从网上下载数据集
  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   source_url + TRAIN_IMAGES)
  with gfile.Open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   source_url + TRAIN_LABELS)
  with gfile.Open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   source_url + TEST_IMAGES)
  with gfile.Open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   source_url + TEST_LABELS)
  with gfile.Open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError('Validation size should be between 0 and {}. Received: {}.'
                     .format(len(train_images), validation_size))

  # 默认情况下,5000个作为validation,其余的作为训练数据
  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  # 最终数据集被分为3部分:训练、评估、测试
  return base.Datasets(train=train, validation=validation, test=test)

上面代码的意图比较简单,就是下载数据,最终数据集被分为3部分:训练、评估、测试,然后封装成一个tuple类型。使用方法分别为:mnist.train、mnist.validation、mnist.test。

这个三个子数据集的成员方法类似,其定义在tensorflow\tensorflow\contrib\learn\python\learn\datasets\mnist.py,有以下成员方法:

  # 获取所有图片
  @property
  def images(self):
    return self._images

  # 获取所有标注 
  @property
  def labels(self):
    return self._labels

  # 获取数据集大小
  @property
  def num_examples(self):
    return self._num_examples

  # 记录用了多少轮了
  @property
  def epochs_completed(self):
    return self._epochs_completed

  # 获取一批数据,大小由batch_size确定
  def next_batch(self, batch_size, fake_data=False, shuffle=True):

主要方法就是mnist.train.num_examples和mnist.train.next_batch(XXXX)。前者用于获取数据集大小,后者用于获取一个batch的数据。

到这里,基本介绍完了mnist的input_data模块。

使用例子:

Xtr, Ytr = mnist.train.next_batch(5000) # 5000个作为训练使用(nn candidates)
Xte, Yte = mnist.test.next_batch(200) #200个用于测试
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值