Tensorflow2.0之经典数据集加载
常用数据集
在 TensorFlow 中,keras.datasets 模块提供了常用经典数据集的自动下载、管理、加载与转换功能,并且提供了tf.data.Dataset 数据集对象,方便实现多线程(Multi-thread),预处理(Preprocess),随机打散(Shuffle)和批训练(Train on batch)等常用数据集功能。
对于常用的数据集,如:
❑ Boston Housing 波士顿房价趋势数据集,用于回归模型训练与测试
❑ CIFAR10/100 真实图片数据集,用于图片分类任务
❑ MNIST/Fashion_MNIST 手写数字图片数据集,用于图片分类任务
❑ IMDB 情感分类任务数据集
这些数据集在机器学习、深度学习的研究和学习中使用的非常频繁。对于新提出的算法,一般优先在简单的数据集上面测试,再尝试迁移到更大规模、更复杂的数据集上。通过 datasets.xxx.load_data()即可实现经典数据集的自动加载,其中xxx 代表具体的数据集名称。TensorFlow 会默认将数据缓存在用户目录下的.keras/datasets 文件夹,如图 5.6所示,用户不需要关心数据集是如何保存的。如果当前数据集不在缓存中,则会自动从网站下载和解压,加载;如果已经在缓存中,自动完成加载:
In [66]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets # 导入经典数据集加载模块
# 加载MNIST 数据集
(x, y), (x_test, y_test) = datasets.mnist.load_data()
print('x:', x.shape, 'y:', y.shape, 'x test:', x_test.shape, 'y test:',
y_test)
Out [66]:
x: (60000, 28, 28) y: (60000,) x test: (10000, 28, 28) y test: [7 2 1 ... 4
5 6]
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
通过load_data()会返回相应格式的数据,对于图片数据集MNIST, CIFAR10 等,会返回2个tuple,第一个tuple 保存了用于训练的数据x,y 训练集对象;第2 个tuple 则保存了用于测试的数据x_test,y_test 测试集对象,所有的数据都用Numpy.array 容器承载。
数据加载进入内存后,需要转换成Dataset 对象,以利用TensorFlow 提供的各种便捷功能。通过Dataset.from_tensor_slices 可以将训练部分的数据图片x 和标签y 都转换成Dataset 对象:
train_db = tf.data.Dataset.from_tensor_slices((x, y))
将数据转换成Dataset 对象后,一般需要再添加一系列的数据集标准处理步骤,如随机打散,预处理,按批装载等
随机打散
通过 Dataset.shuffle(buffer_size)工具可以设置Dataset 对象随机打散数据之间的顺序,防止
每次训练时数据按固定顺序产生,从而使得模型尝试“记忆”住标签信息:
其中buffer_size 指定缓冲池的大小,一般设置为一个较大的参数即可。通过Dataset 提供的这些工具函数会返回新的Dataset 对象,可以通过
db = db. shuffle(). step2(). step3. ()
方式完成所有的数据处理步骤,实现起来非常方便。
批训练
为了利用显卡的并行计算能力,一般在网络的计算过程中会同时计算多个样本,我们把这种训练方式叫做批训练,其中样本的数量叫做batch size。为了一次能够从Dataset 中产生batch size 数量的样本,需要设置Dataset 为批训练方式:
train_db = train_db.batch(128)
其中128 为batch size 参数,即一次并行计算128 个样本的数据。Batch size 一般根据用户的GPU 显存资源来设置,当显存不足时,可以适量减少batch size 来减少算法的显存使用量。
预处理
从 keras.datasets 中加载的数据集的格式大部分情况都不能满足模型的输入要求,因此
需要根据用户的逻辑自己实现预处理函数。Dataset 对象通过提供map(func)工具函数可以
非常方便地调用用户自定义的预处理逻辑,它实现在func 函数里:
预处理函数实现在preprocess 函数中,传入函数引用即可
train_db = train_db.map(preprocess)
考虑 MNIST 手写数字图片,从keras.datasets 中经.batch()后加载的图片x shape 为[𝑏, 28,28],像素使用0~255 的整形表示;标注shape 为[𝑏],即采样的数字编码方式。实际的神经网络输入,一般需要将图片数据标准化到[0,1]或[−1,1]等0 附近区间,同时根据网络的设置,需要将shape [28,28] 的输入Reshape 为合法的格式;对于标注信息,可以选择在预处理时进行one-hot 编码,也可以在计算误差时进行one-hot 编码。
我们可以将MNIST 图片数据映射到𝑥 ∈ [0,1]区间,视图调整为[𝑏, 28 ∗ 28];对于标注y,我们选择在预处理函数里面进行one-hot 编码:
def preprocess(x, y): # 自定义的预处理函数
# 调用此函数时会自动传入x,y 对象,shape 为[b, 28, 28], [b]
# 标准化到0~1
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [-1, 28*28]) # 打平
y = tf.cast(y, dtype=tf.int32) # 转成整形张量
y = tf.one_hot(y, depth=10) # one-hot 编码
# 返回的x,y 将替换传入的x,y 参数,从而实现数据的预处理功能
return x,y
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
循环训练
对于 Dataset 对象,在使用时可以通过
for step, (x,y) in enumerate(train_db): # 迭代数据集对象,带step 参数
或
for x,y in train_db: # 迭代数据集对象
方式进行迭代,每次返回的x,y 对象即为批量样本和标签,当对train_db 的所有样本完成一次迭代后,for 循环终止退出。我们一般把完成一个batch 的数据训练,叫做一个step;通过多个step 来完成整个训练集的一次迭代,叫做一个epoch。在实际训练时,通常需要对数据集迭代多个epoch 才能取得较好地训练效果:
for epoch in range(20): # 训练Epoch 数
for step, (x,y) in enumerate(train_db): # 迭代Step 数
# training...
- 1
- 2
- 3
此外,也可以通过设置:
train_db = train_db.repeat(20) # 数据集跌打20 遍才终止使得for x,y in train_db 循环迭代20 个epoch 才会退出。不管使用上述哪种方式,都能取得一样的效果。