How to use fit_generator and keras.utils.Sequence to generate data for a model with multiple outputs
Useful links
- https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
- https://github.com/keras-team/keras/issues/8130
Why use Sequence
Sequence are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch which is not the case with generators.
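In other words, a plain generator cannot guarantee that each sample is used exactly once per epoch (in particular, when several workers read from it in parallel they may yield the same data), whereas a Sequence can. This keeps some samples from being seen more often than others within an epoch, which would otherwise encourage overfitting to them.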
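The guarantee comes from indexed access: every batch is requested by its index idx, so even if workers ask for batches in an arbitrary order, the union of all batches in one epoch is exactly the training set. A plain generator has no such index, and with use_multiprocessing=True each worker may run its own copy and yield duplicates. The pure-Python sketch below simulates an out-of-order epoch; batch_indices is a hypothetical helper, not part of Keras.

```python
import math
import random


def batch_indices(n_samples, batch_size, idx):
    """Sample indices covered by batch number idx (what __getitem__(idx) would load)."""
    start = idx * batch_size
    return list(range(start, min(start + batch_size, n_samples)))


n_samples, batch_size = 10, 4
n_batches = math.ceil(n_samples / batch_size)  # what __len__ would return: 3

# Workers may request batches in any order; simulate that with a shuffle.
order = list(range(n_batches))
random.shuffle(order)
seen = [i for b in order for i in batch_indices(n_samples, batch_size, b)]

print(sorted(seen) == list(range(n_samples)))  # True: every sample appears exactly once
```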
Methods to implement when using Sequence
Every Sequence must implement the __getitem__ and the __len__ methods. If you want to modify your dataset between epochs you may implement on_epoch_end. The method __getitem__ should return a complete batch.
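__getitem__(idx) must return one complete batch for training. __len__ returns the number of batches per epoch, usually the number of training samples divided by the batch size (rounded up to keep the last partial batch, or down to drop it); if steps_per_epoch is not passed to fit_generator, this value is used in its place.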
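A minimal sketch of a complete Sequence for a model with three outputs, assuming all images and masks already fit in memory as NumPy arrays. The class name MultiOutputSequence and its constructor arguments are hypothetical; only __len__, __getitem__ and on_epoch_end come from the Sequence interface described above. The import fallback exists only so the sketch can be tried without Keras installed.

```python
import math

import numpy as np

try:
    from keras.utils import Sequence
except ImportError:  # fallback so the sketch runs without Keras; use the real Sequence in practice
    Sequence = object


class MultiOutputSequence(Sequence):
    """Hypothetical example: yields (images, [mask_1, mask_2, mask_3]) batches."""

    def __init__(self, images, masks, batch_size=8, shuffle=True):
        super().__init__()
        self.images = images  # np.ndarray, shape (N, ...)
        self.masks = masks    # list of np.ndarray, one per model output, each shape (N, ...)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(images))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; ceil keeps the final partial batch.
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        # Select the same sample indices for the input and for every target.
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        x = self.images[batch_idx]
        y = [m[batch_idx] for m in self.masks]
        return x, y

    def on_epoch_end(self):
        # Reshuffle the sample order between epochs.
        if self.shuffle:
            np.random.shuffle(self.indices)
```

With Keras 2 this would be passed as model.fit_generator(seq, epochs=10, workers=4, use_multiprocessing=True), leaving steps_per_epoch unset so that len(seq) is used. In newer TensorFlow/Keras releases fit_generator is deprecated (and removed in Keras 3), and model.fit(seq, ...) accepts the Sequence directly.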
Things to watch out for
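- If __getitem__ returns several masks (one per model output), return them like this:
def __getitem__(self, idx):
    # Build batch number idx; every array has the batch as its first axis.
    img = ...   # your input images
    gt_1 = ...  # your ground truth for output 1
    gt_2 = ...  # your ground truth for output 2
    gt_3 = ...  # your ground truth for output 3
    # ... any further preprocessing ...
    # Inputs first, then a list with one target per model output, in the same order as the outputs.
    return img, [gt_1, gt_2, gt_3]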
- In the code above, img, gt_1, gt_2 and gt_3 must all be NumPy arrays (np.ndarray), each with the batch dimension first.
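Per-sample data often comes out of a loading loop as Python lists of arrays, which have to be stacked into batch arrays before being returned. A minimal sketch using np.stack; make_batch and the data shapes are hypothetical.

```python
import numpy as np


def make_batch(samples_x, samples_ys):
    """Stack per-sample arrays into batch arrays (hypothetical helper).

    samples_x:  list of input arrays, one per sample
    samples_ys: list with one entry per model output; each entry is a list of
                per-sample ground-truth arrays
    Returns (x, [y_1, y_2, ...]) in the form __getitem__ should return.
    """
    x = np.stack(samples_x, axis=0)                 # (batch, H, W, C)
    ys = [np.stack(s, axis=0) for s in samples_ys]  # one (batch, ...) array per output
    return x, ys


# Hypothetical data for a batch of 4 samples.
batch = 4
imgs = [np.zeros((64, 64, 3), dtype=np.float32) for _ in range(batch)]
gt_1 = [np.zeros((64, 64, 1), dtype=np.float32) for _ in range(batch)]
gt_2 = [np.zeros((64, 64, 1), dtype=np.float32) for _ in range(batch)]
gt_3 = [np.zeros((32, 32, 1), dtype=np.float32) for _ in range(batch)]  # outputs may differ in size

x, ys = make_batch(imgs, [gt_1, gt_2, gt_3])
print(x.shape, [y.shape for y in ys])
# (4, 64, 64, 3) [(4, 64, 64, 1), (4, 64, 64, 1), (4, 32, 32, 1)]
```

Keras also accepts the targets as a dict keyed by output layer names instead of a list; with a list, the order must match the order of the model's outputs.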