前言
当 NLP 玩家遇到一个 CV 图像分类的任务时,老实的说,我是有点懵逼的。。。
任务目标是,训练一个层数较少,结构较为简单的图像分类模型,使用其最后一层隐藏层输出,作为特征向量来表征图像。
之前都是使用 Keras 较多,于是本次准备借着这个简单的任务上手 TensorFlow 2.1 。
数据加载
Python generator 出现的问题
TensorFlow 2.1 自带的 tf.data.Dataset 处理训练数据十分好用,并且自带 shuffle,repeat,和划分 batch 的方法。可以通过python generator, numpy list, Tensor slices 等数据结构直接构成 Dataset。
我训练使用的数据是文档中的插图,5个类别共 10w 张。
起初我使用的方法是:构造一个 python generator,训练时,使用 tf 自带的 tf.io.read_file() 和 tf.image.decode_jpeg() 方法从磁盘中读取数据,再使用 tf.data.Dataset.from_generator 生成数据集。
但训练时发现这样的数据处理有着很大的问题:受制于generator 的读取数据速度,batch 数据生成的速度跟不上 GPU 的训练速度