在 TensorFlow 中,keras.datasets 模块提供了常用经典数据集的自动下载、管理、加载与转换功能,并且提供了 tf.data.Dataset 数据集对象,方便实现多线程(Multi-thread),预处理(Preprocess),随机打散(Shuffle)和批训练(Train on batch)等常用数据集功能。
常用的数据集,如:
⚪ Boston Housing 波士顿房价趋势数据集,用于回归模型训练与测试
⚪ CIFAR10/100 真实图片数据集,用于图片分类任务
⚪ MNIST/Fashion_MNIST 手写数字图片数据集,用于图片分类任务
⚪ IMDB 情感分类任务数据集
通过datasets.xxx.laod_data()即可实现经典数据集的自动加载,其中xxx代表具体的数据集名称。TensorFlow 会默认将数据缓存在用户目录下的.keras/datasets 文件夹,如下图所示。
用户不需要关心数据集是如何保存的。如果当前数据集不在缓存中,则会自动从网站下载和解压,加载;如果已经在缓存中,自动完成加载:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets # 导入经典数据集加载模块
# 加载 MNIST 数据集
(x, y), (x_test, y_test) = datasets.mnist.load_data()
print('x:', x.shape, 'y:', y.shape, 'x test:', x_test.shape, 'y test:', y_test)
x: (60000, 28, 28) y: (60000,) x test: (10000, 28, 28) y test: [7 2 1 ... 4 5 6]