tensorflow.keras.utils.Sequence的使用(控制模型从文件读入batch_size的数据
在使用keras的时候,一般使用model.fit()来传入训练数据,fit()接受多种类型的数据:
1.数组类型(如numpy等)。注意,tensorflow2以后的版本在接受h5py类型数据时,容易出错,原因我也不是特别懂
2.dataset类型
3.python generator,但是限制比较多,一般要在编写python generator的平 台下运行模型
4.tensorflow.keras.utils.Sequence,和python generator差不多,但是限制较少,可迁移性更好
第4种类型是本文要讲的重点类型
官方例子
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return math.ceil(len(self.x)</