1. Dataset introduction
The official TensorFlow MNIST: (placed under the database2 folder)
The official Keras MNIST: (placed under the database3 folder)
2. Reading the official TensorFlow MNIST with TensorFlow V1:
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets('./database2/', one_hot=True)  # relative path
# tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet
print(type(mnist))  # <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
batch = mnist.train.next_batch(100)
print(type(batch))  # <class 'tuple'>
x = mnist.train.images
y = mnist.train.labels
print(type(x), x.shape)  # <class 'numpy.ndarray'> (55000, 784)
print(type(y), y.shape)  # <class 'numpy.ndarray'> (55000, 10)
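Since batch is simply the (images, labels) tuple returned by next_batch, it can be unpacked directly. A quick sanity check (not in the original code):
batch_images, batch_labels = batch
print(batch_images.shape)  # (100, 784)
print(batch_labels.shape)  # (100, 10) because one_hot=True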
Looking through the TensorFlow source code, as shown below:
D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\examples\tutorials\mnist\input_data.py [read_data_sets]
D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py
D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py [maybe_download]
@deprecated(None, 'Please write your own downloading logic.')
def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

  Returns:
      Path to resulting file.
  """
  if not gfile.Exists(work_directory):
    gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not gfile.Exists(filepath):
    temp_file_name, _ = urlretrieve_with_retry(source_url)
    gfile.Copy(temp_file_name, filepath)
    with gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath
As the source shows, read_data_sets can directly read the already-downloaded official MNIST files from the given relative path; if there are no files to read under that path, it downloads them there first.
read_data_sets returns an object of <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>, whose data members train.images and train.labels expose the data as numpy arrays.
During training, these numpy arrays can then be fed to tf.placeholder objects via feed_dict, as sketched below.
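As a minimal sketch of that feed_dict pattern (the softmax model below is an illustrative stand-in of mine, not the training code promised in the later post):
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets('./database2/', one_hot=True)

# Placeholders matching the (N, 784) images and (N, 10) one-hot labels.
x_ph = tf.placeholder(tf.float32, [None, 784])
y_ph = tf.placeholder(tf.float32, [None, 10])

# A toy softmax-regression head, purely for illustration.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x_ph, W) + b
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_ph, logits=logits))
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_x, batch_y = mnist.train.next_batch(100)
    # The numpy arrays are bound to the placeholders through feed_dict.
    _, loss_val = sess.run([train_op, loss], feed_dict={x_ph: batch_x, y_ph: batch_y})
    print(loss_val)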
A later post will walk through training on the official TensorFlow MNIST dataset with TensorFlow.
3. Reading the official Keras MNIST with Keras (TensorFlow V2):
import tensorflow as tf
import keras
# from tensorflow import keras

def preprocess(images, labels):
    '''
    The simplest preprocessing function:
    convert numpy to Tensor, one-hot encode the labels (needed for a
    classification problem), and normalize the training data.
    '''
    # Convert the numpy image data to a Tensor and normalize to [0, 1].
    images = tf.cast(images, dtype=tf.float32) / 255
    # Convert the labels to one-hot encoding.
    labels = tf.cast(labels, dtype=tf.int32)
    labels = tf.one_hot(labels, depth=10)
    return images, labels

abs_path_to_dataset = 'H:/Leon/CODE/python_projects/master_ImRecognition/dataset/MNIST/database3/mnist.npz'
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data(path=abs_path_to_dataset)  # absolute path
print(type(x), x.shape)  # <class 'numpy.ndarray'> (60000, 28, 28)
print(type(y), y.shape)  # <class 'numpy.ndarray'> (60000,)
db_train = tf.data.Dataset.from_tensor_slices((x, y))
print(db_train)  # <DatasetV1Adapter shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>
print(type(db_train))  # <class 'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter'>
print(db_train.output_shapes)  # (TensorShape([Dimension(28), Dimension(28)]), TensorShape([]))
# Each transformation returns a new Dataset, so the result must be reassigned.
db_train = db_train.shuffle(1000)
db_train = db_train.map(preprocess)  # elements are (image, label) pairs
db_train = db_train.batch(64)
db_train = db_train.repeat(2)
print(type(db_train))  # <class 'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter'>
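To actually pull batches out of this pipeline in TF 1.13's graph mode, a one-shot iterator works; this is a sketch on top of the code above, not part of the original post:
iterator = db_train.make_one_shot_iterator()
next_images, next_labels = iterator.get_next()
with tf.Session() as sess:
    batch_images, batch_labels = sess.run([next_images, next_labels])
    print(batch_images.shape, batch_labels.shape)  # (64, 28, 28) (64, 10) for a full batch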
There are two ways to import keras:
- import keras
- from tensorflow import keras
Although both modules are named keras, they are completely separate:
import keras loads the standalone package, which is not under the tensorflow package path; it lives under "D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\keras".
from tensorflow import keras loads the subpackage shipped inside tensorflow, under "D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\python\keras".
Both versions of keras load the data with keras.datasets.mnist.load_data. When path is given as an absolute path and no file exists there yet, the data is downloaded to that path; when path is omitted, it is downloaded to "C:/Users/Leon_PC/.keras/datasets". The two call styles are sketched below.
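For illustration (the absolute path below is a made-up example):
import keras  # or: from tensorflow import keras

# No path given: the file is cached as ~/.keras/datasets/mnist.npz.
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()

# Absolute path given: downloaded there if missing, read from there afterwards.
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data(path='D:/data/MNIST/mnist.npz')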
The two packages still differ, however:
When the standalone keras package's load_data is called, it looks for the dataset under path (an absolute path); if nothing is found, it downloads from s3.amazonaws.com, which is reachable without a VPN.
D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\keras\datasets\mnist.py [load_data]
def load_data(path='mnist.npz'):
    """Loads the MNIST dataset.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)
The keras subpackage inside tensorflow also looks for the dataset under path (an absolute path), but if nothing is found it downloads from storage.googleapis.com, which can only be reached with a VPN.
D:\Users\Leon_PC\Anaconda3\envs\tensorflow1_13_1\Lib\site-packages\tensorflow\python\keras\datasets\mnist.py [load_data]
@tf_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
  """Loads the MNIST dataset.

  Arguments:
      path: path where to cache the dataset locally
          (relative to ~/.keras/datasets).

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

  License:
      Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
      which is a derivative work from original NIST datasets.
      MNIST dataset is made available under the terms of the
      [Creative Commons Attribution-Share Alike 3.0 license.](
      https://creativecommons.org/licenses/by-sa/3.0/)
  """
  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'mnist.npz',
      file_hash='8a61469f7ea1b51cbae51d4f78837e45')
  with np.load(path) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
  return (x_train, y_train), (x_test, y_test)
keras's load_data returns the data as numpy arrays. Combined with tf.data.Dataset.from_tensor_slices, this yields a <class 'tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter'> object, whose member methods make the usual steps convenient: shuffle shuffles the data, map applies a mapping (e.g. preprocessing the images and labels inside the preprocess function), batch sets the batch size, and repeat sets how many times the dataset is repeated; see the chained sketch below.
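Because each of these methods returns a new Dataset, the four steps are often chained into a single expression; this sketch is equivalent to the step-by-step calls in the code above:
db_train = (tf.data.Dataset.from_tensor_slices((x, y))
            .shuffle(1000)    # shuffle with a buffer of 1000 examples
            .map(preprocess)  # normalize images, one-hot the labels
            .batch(64)        # batch size 64
            .repeat(2))       # go through the dataset twice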
Later, during training, these numpy arrays can again be fed to tf.placeholder objects via feed_dict.
A later post will walk through training on the official Keras MNIST dataset with TensorFlow, covering the TensorFlow 2.x style of training.
References:
https://www.cnblogs.com/heze/p/12076792.html
https://blog.csdn.net/i_love_zxy/article/details/103543108