如果训练网络时,针对的是大规模数据集,如图像数据集,其不能完全读取加载到内存里,那么就需要利用用到 data generator了. data generator 将数据分为 batches,再送入网络进行训练.
TnesorFlow 有对应的 API,但其 API 比较复杂,且容易出错.
对于习惯于 Keras 的人来说,Keras 减少了学习繁琐 API 的成本(未来可能会发生变化),只需关注于模型设计.
另一个采用 Keras 的 Sequence class 作为 batch data generator 的优势在于,Keras 能够处理所有的 multi-threading 和并行化(parallelization),以确保训练过程中不需要 batch data generation 的数据等待. 其背后的原理是,采用了 multiplt CPUs 核提前拉取了 batches 数据.
1. 采用 Keras 的 Sequence Class
1.1 Keras 的 Sequence 文档keras.utils.Sequence()
每个 Sequence 必须包含 __getitem__ 和 __len__ 方法的实现.
如果需要修改自定义数据集的 epochs 间的数据,则需要实现 on_epoch_end.
__getitem__返回一个完整的 batch.
Sequence 是进行 Multiprocessing 的更加安全的方式. 其保证了网络只对每个 epoch 内的每个样本训练一次.
例示:from skimage.io import imread
from skimage.transform import resize
import numpy as np