为什么需要自定义数据生成器?
在实际的深度学习项目中,我们经常会遇到这样的情况:训练数据集大到内存根本装不下,传统的一次性加载方式直接导致程序崩溃;或者数据需要进行复杂的实时增强,普通的Dataset API搞不定;再或者数据源根本不是标准图片文件,而是需要动态解析的特殊格式…
这时候就需要祭出我们的杀手锏——自定义数据生成器(Custom Data Generator)。这个神器能帮我们做到:
- 内存友好:像流水线一样按需加载数据
- 灵活扩展:自由添加各种预处理逻辑
- 性能优化:充分利用多核CPU实现并行加载
今天我就带大家从零开始,手撸一个工业级的自定义数据生成器。本文包含20+个关键代码片段,最后会给出完整可运行的实例代码。准备好了吗?咱们发车!
一、数据生成器核心原理
1.1 生成器工作流程
数据生成器的本质是一个迭代器,在训练过程中:
- 主线程负责模型计算(GPU)
- 后台线程预加载下一批数据(CPU)
- 通过队列实现计算与加载的流水线
1.2 必须实现的接口
继承tf.keras.utils.Sequence
类时,必须实现三个核心方法:
class CustomGenerator(tf.keras.utils.Sequence):
def __len__(self):
"""返回总批次数"""
def __getitem__(self, index):
"""返回第index批的数据"""
def on_epoch_end(self):
"""每个epoch结束时调用"""
二、实战:医疗影像生成器
假设我们要处理10万张256x256的CT扫描图片,每张图片对应一个包含5种疾病的标签。由于数据量太大,无法一次性加载到内存。
2.1 项目结构
medical_ct/
├── data/
│ ├── train/
│ │ ├── patient_001/
│ │ │ ├── slice_001.png
│ │ │ └── labels.csv
│ │ └── ...
│ └── val/
└── generators/
└── ct_generator.py
三、分步实现生成器
3.1 初始化参数
class CTScanGenerator(tf.keras.utils.Sequence):
def __init__(self, data_dir, batch_size=32, shuffle=True, augment=False):
self.data_dir = data_dir
self.batch_size = batch_size
self.shuffle = shuffle
self.augment = augment
# 获取所有患者目录
self.patient_dirs = sorted([
os.path.join(data_dir, d)
for d in os.listdir(data_dir)
if os.path.isdir(os.path.join(data_dir, d))
])
# 提前计算总样本数
self.total_slices = 0
for p_dir in self.patient_dirs:
num_slices = len([f for f in os.listdir(p_dir) if f.endswith('.png')])
self.total_slices += num_slices
self.indices = np.arange(self.total_slices)
self.on_epoch_end()
3.2 实现__len__方法
def __len__(self):
"""计算总批次数"""
return int(np.ceil(self.total_slices / self