在 TensorFlow 中,keras.datasets
模块提供了常用经典数据集的自动下载、管理、加载与转换功能,并且提供了tf.data.Dataset
数据集对象,方便实现多线程(Multi-thread),预处理(Preprocess),随机打散(Shuffle)和批训练(Train on batch)等常用数据集功能。
常用的数据集:
(1)Boston Housing 波士顿房价趋势数据集,用于回归模型训练与测试
(2)CIFAR10/100 真实图片数据集,用于图片分类任务
(3)MNIST/Fashion_MNIST 手写数字图片数据集,用于图片分类任务
(4)IMDB 情感分类任务数据集
这些数据集在机器学习、深度学习的研究和学习中使用的非常频繁。对于新提出的算法,
一般优先在简单的数据集上面测试,再尝试迁移到更大规模、更复杂的数据集上。
通过 datasets.xxx.load_data()
即可实现经典数据集的自动加载,其中xxx 代表具体的数据集名称。
TensorFlow 会默认将数据缓存在用户目录下的.keras/datasets 文件夹,如图 所示,用户不需要关心数据集是如何保存的。如果当前数据集不在缓存中,则会自动从网
站下载和解压,加载;如果已经在缓存中,自动完成加载:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets # 导入经典数据集加载模块
# 加载MNIST 数据集
(x, y), (x_test, y_test) = datasets.mnist.load_data()
通过load_data()
会返回相应格式的数据,对于图片数据集MNIST, CIFAR10 等,会返回2个tuple,第一个tuple 保存了用于训练的数据x,y 训练集对象;第2 个tuple 则保存了用于
测试的数据x_test,y_test 测试集对象,所有的数据都用Numpy.array 容器承载。
数据加载进入内存后,需要转换成 Dataset 对象,以利用TensorFlow 提供的各种便捷
功能。通过Dataset.from_tensor_slices
可以将训练部分的数据图片x 和标签y 都转换成Dataset 对象:
train_db = tf.data.Dataset.from_tensor_slices((x, y))
将数据转换成 Dataset 对象后,一般需要再添加一系列的数据集标准处理步骤,如随机打
散,预处理,按批装载等。