【TensorFlow】Reading the official MNIST dataset with TensorFlow and Keras

1. Dataset overview

The official TensorFlow MNIST files (placed under the database2 folder):

The official Keras MNIST file, mnist.npz (placed under the database3 folder):

2. Reading the official TensorFlow MNIST with TensorFlow 1.x:

import tensorflow as tf 
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets('./database2/', one_hot=True)  # relative path
# each split (mnist.train etc.) is a tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet
print(type(mnist))  # <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>

batch = mnist.train.next_batch(100)
print(type(batch))  # <class 'tuple'>

x = mnist.train.images
y = mnist.train.labels
print(type(x), x.shape)  # <class 'numpy.ndarray'> (55000, 784)
print(type(y), y.shape)  # <class 'numpy.ndarray'> (55000, 10)

Looking into the TensorFlow source code, in the following files:

D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\examples\tutorials\mnist\input_data.py [read_data_sets]
D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py
D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py [maybe_download]

@deprecated(None, 'Please write your own downloading logic.')
def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

  Returns:
      Path to resulting file.
  """
  if not gfile.Exists(work_directory):
    gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not gfile.Exists(filepath):
    temp_file_name, _ = urlretrieve_with_retry(source_url)
    gfile.Copy(temp_file_name, filepath)
    with gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath

 

From this we can see that read_data_sets reads the already-downloaded official MNIST files directly from the relative path; if there are no files to read at that path, it downloads them there instead.

read_data_sets returns a <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'> object, whose members train.images and train.labels expose the data as numpy arrays.

During training, these numpy arrays can then be fed to tf.placeholder objects via feed_dict.
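A minimal sketch of that feeding step (TF 1.x; the graph here is just a shape check standing in for a real model):

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets('./database2/', one_hot=True)

# placeholders matching the (N, 784) images and (N, 10) one-hot labels above
x_ph = tf.placeholder(tf.float32, [None, 784])
y_ph = tf.placeholder(tf.float32, [None, 10])

with tf.Session() as sess:
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # any op consuming x_ph / y_ph is run the same way via feed_dict
    print(sess.run(tf.shape(x_ph), feed_dict={x_ph: batch_xs, y_ph: batch_ys}))  # [100 784]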

A follow-up post will walk through training on the official TensorFlow MNIST dataset.

 

 

3. Reading the official Keras MNIST with Keras (tensorflow-V2):

import tensorflow as tf
import keras
# from tensorflow import keras  # the alternative import, discussed below

def preprocess(images, labels):
	'''
	The simplest preprocessing function: convert the numpy data to Tensors,
	one-hot encode the labels (needed for classification), and normalize
	the images. Note that from_tensor_slices((x, y)) yields (image, label)
	pairs, so the image must come first in the signature.
	'''
	# normalize the images to [0, 1] while casting to float32
	images = tf.cast(images, dtype=tf.float32) / 255
	# cast the labels to int32 (tf.one_hot needs an integer dtype)
	labels = tf.cast(labels, dtype=tf.int32)
	# convert the labels to one-hot encoding
	labels = tf.one_hot(labels, depth=10)
	return images, labels

abs_path_to_dataset = 'H:/Leon/CODE/python_projects/master_ImRecognition/dataset/MNIST/database3/mnist.npz'
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data(path=abs_path_to_dataset)  # absolute path
print(type(x), x.shape)  # <class 'numpy.ndarray'> (60000, 28, 28)
print(type(y), y.shape)  # <class 'numpy.ndarray'> (60000,)
db_train = tf.data.Dataset.from_tensor_slices((x, y))
print(db_train)        # <DatasetV1Adapter shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>
print(type(db_train))  # <class 'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter'>
print(db_train.output_shapes)  # (TensorShape([Dimension(28), Dimension(28)]), TensorShape([]))
# tf.data transformations return new datasets rather than mutating in place,
# so each result must be reassigned:
db_train = db_train.shuffle(1000)
db_train = db_train.map(preprocess)
db_train = db_train.batch(64)
db_train = db_train.repeat(2)
print(type(db_train))  # still <class 'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter'>

There are two ways to import keras:

  • import keras
  • from tensorflow import keras

Although both modules are named keras, they are completely separate:

import keras imports the standalone keras package, which lives outside the tensorflow package path, here at "D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\keras".

from tensorflow import keras imports the keras subpackage bundled inside tensorflow, here at "D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\python\keras".
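A quick way to confirm which keras you actually imported is to print each module's __file__ attribute; a small sketch (the exact paths depend on your environment):

import keras as standalone_keras
from tensorflow import keras as tf_keras

print(standalone_keras.__file__)  # ...\site-packages\keras\__init__.py
print(tf_keras.__file__)          # under ...\site-packages\tensorflow\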

Both versions of keras load the data with the keras.datasets.mnist.load_data function. When path is given as an absolute path, the data is downloaded to that path if the file does not already exist there. If path is not specified, the data is downloaded under "C:/Users/Leon_PC/.keras/datasets".
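A sketch of both call patterns (the absolute path is the database3 location used above):

# no path argument: the archive is cached under the current user's
# .keras/datasets directory (here C:/Users/Leon_PC/.keras/datasets/mnist.npz)
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()

# absolute path: looked up at exactly this location, downloaded there if missing
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data(
    path='H:/Leon/CODE/python_projects/master_ImRecognition/dataset/MNIST/database3/mnist.npz')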

There is still a difference between the two, though:

The standalone keras package's load_data looks for the dataset at path (an absolute path); if it is not found, it downloads from s3.amazonaws.com, which is reachable from mainland China without a VPN.

 D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\keras\datasets\mnist.py [load_data]

def load_data(path='mnist.npz'):
    """Loads the MNIST dataset.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)

 

The keras subpackage inside tensorflow also looks for the dataset at path (an absolute path), but if it is not found it downloads from storage.googleapis.com, which can only be reached from mainland China with a VPN.

 D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\python\keras\datasets\mnist.py [load_data]

@tf_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
  """Loads the MNIST dataset.

  Arguments:
      path: path where to cache the dataset locally
          (relative to ~/.keras/datasets).

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

  License:
      Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
      which is a derivative work from original NIST datasets.
      MNIST dataset is made available under the terms of the
      [Creative Commons Attribution-Share Alike 3.0 license.](
      https://creativecommons.org/licenses/by-sa/3.0/)
  """
  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'mnist.npz',
      file_hash='8a61469f7ea1b51cbae51d4f78837e45')
  with np.load(path) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

    return (x_train, y_train), (x_test, y_test)
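Either way, the cached file is the same mnist.npz, which is just a numpy .npz archive; you can inspect it directly with np.load (a sketch, using the database3 path from above):

import numpy as np

with np.load('H:/Leon/CODE/python_projects/master_ImRecognition/dataset/MNIST/database3/mnist.npz') as f:
    print(sorted(f.files))     # ['x_test', 'x_train', 'y_test', 'y_train']
    print(f['x_train'].shape)  # (60000, 28, 28)
    print(f['y_train'].shape)  # (60000,)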

 

keras's load_data returns the data as numpy arrays. Combined with tf.data.Dataset.from_tensor_slices, this yields a <class 'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter'> object whose member methods make building an input pipeline convenient: shuffle shuffles the data, map applies a mapping (the preprocess function above handles the images and labels), batch sets the batch size, and repeat sets how many times the dataset is repeated. Each of these returns a new dataset, so chain them and keep the result, as in the sketch below.
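A minimal sketch of the corrected pipeline, chained in one expression and consumed through a TF 1.x one-shot iterator (the shapes in the comments assume the preprocess function above):

db_train = (tf.data.Dataset.from_tensor_slices((x, y))
            .shuffle(1000)    # shuffle with a 1000-example buffer
            .map(preprocess)  # cast/normalize images, one-hot the labels
            .batch(64)        # 64 examples per step
            .repeat(2))       # two passes over the data

iterator = db_train.make_one_shot_iterator()
images_batch, labels_batch = iterator.get_next()
with tf.Session() as sess:
    imgs, lbls = sess.run([images_batch, labels_batch])
    print(imgs.shape, lbls.shape)  # (64, 28, 28) (64, 10)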


Later, during training, the numpy arrays returned by load_data can likewise be fed to tf.placeholder objects via feed_dict.

A follow-up post will walk through training the official Keras MNIST dataset with TensorFlow, covering the TensorFlow 2.x training workflow.

