Build a data generator and train with fit_generator for large datasets; this avoids the out-of-memory errors that fit can cause by loading the entire dataset into memory at once.
To keep the same training behaviour as fit, override the corresponding Keras class. Keras ships three generator families:
- sequence / time series: TimeseriesGenerator (tf.keras.preprocessing.sequence.TimeseriesGenerator)
- text: text_to_word_sequence (tf.keras.preprocessing.text.text_to_word_sequence)
- image: ImageDataGenerator (tf.keras.preprocessing.image.ImageDataGenerator)
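As a rough sketch of what TimeseriesGenerator yields, the windowing can be reproduced by hand with numpy (the array sizes and the length/batch_size values below are made up for illustration):

```python
import numpy as np

# toy univariate series: 0, 1, ..., 19
data = np.arange(20).reshape(-1, 1)
length, batch_size = 3, 4  # window length and batch size (arbitrary)

# manually build what the generator's first batch would look like:
# each sample is a window of `length` steps, its target is the next step
samples = np.stack([data[i:i + length] for i in range(batch_size)])
targets = np.stack([data[i + length] for i in range(batch_size)])

print(samples.shape)  # (4, 3, 1) -> (batch, window, features)
print(targets.shape)  # (4, 1)
```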
Taking time series as the example, the core is overriding two methods:

    def __len__(self): ...
    def __getitem__(self, index): ...
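Before the full example, here is a minimal sketch of the protocol these two methods implement (in real Keras code the class would inherit keras.utils.Sequence; this standalone version only shows the batching arithmetic):

```python
import numpy as np

class MiniBatcher:
    """Minimal sketch of the Sequence protocol (would inherit
    keras.utils.Sequence in real Keras code)."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch (last partial batch included)
        return int(np.ceil(len(self.data) / self.batch_size))

    def __getitem__(self, index):
        # slice out batch `index`
        return self.data[index * self.batch_size:(index + 1) * self.batch_size]

gen = MiniBatcher(np.arange(10), batch_size=4)
print(len(gen))  # 3 batches: sizes 4, 4, 2
print(gen[2])    # [8 9]
```

Keras calls `__len__` once per epoch to know how many batches to draw, then calls `__getitem__` with each batch index.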
    import numpy as np
    from keras.utils import data_utils  # data_utils.Sequence is the base protocol
    from tensorflow.keras.preprocessing.sequence import TimeseriesGenerator

    class CustomGenerator(TimeseriesGenerator):
        """
        Extends TimeseriesGenerator:
        https://github.com/keras-team/keras/blob/v2.9.0/keras/preprocessing/sequence.py#L55-L231
        """
        def __init__(self, pre_len, kd_flag, **kwargs):
            super(CustomGenerator, self).__init__(**kwargs)
            self.kd_flag = kd_flag
            self.pre_len = pre_len

        def __getitem__(self, index):
            if self.shuffle:
                # draw a random batch of row indices
                rows = np.random.randint(
                    self.start_index, self.end_index + 1, size=self.batch_size)
            else:
                # sequential batch of row indices
                i = self.start_index + self.batch_size * index
                rows = np.arange(
                    i, min(i + self.batch_size, self.end_index + 1))
            samples = self.data[rows]
            if self.kd_flag:
                # third target is a zero placeholder (e.g. for a distillation loss)
                return samples, [samples[:, -1, -1],
                                 samples[:, -self.pre_len:, 1],
                                 np.zeros(samples.shape[0])]
            else:
                return samples, [samples[:, -1, -1],
                                 samples[:, -self.pre_len:, 1]]
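To see what the kd_flag branch returns, the same slicing can be run on a toy batch (the shapes here are made up: 4 samples, 6 timesteps, 3 features, pre_len = 2):

```python
import numpy as np

pre_len = 2
samples = np.arange(4 * 6 * 3).reshape(4, 6, 3)  # (batch, timesteps, features)

y1 = samples[:, -1, -1]          # last timestep, last feature   -> shape (4,)
y2 = samples[:, -pre_len:, 1]    # last pre_len steps, feature 1 -> shape (4, 2)
y3 = np.zeros(samples.shape[0])  # zero placeholder target       -> shape (4,)

print(y1.shape, y2.shape, y3.shape)
```

The model compiled below therefore needs three outputs (and three losses) whose shapes match y1, y2, and y3.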
Once the overrides are in place, the generator can be passed to fit_generator (in TF2, model.fit also accepts Sequence objects directly, and fit_generator is deprecated).

Example (opt, loss1/loss2/loss3, example_model, and the constructor arguments are placeholders):

    gen_input_kd = CustomGenerator(pre_len=pre_len, kd_flag=True,
                                   data=data, targets=targets, length=length)
    example = example_model()
    example.compile(
        optimizer=opt,            # e.g. keras.optimizers.Adam(learning_rate=0.001, amsgrad=True)
        metrics=['mae'],          # tf.keras.metrics.get('mae')
        loss=[loss1, loss2, loss3],
        loss_weights=[0.75, 0.05, 0.2],
    )
    example.fit_generator(gen_input_kd, epochs=epochs, verbose=2)
Original implementation for reference:
keras/sequence.py at 07e13740fd181fc3ddec7d9a594d8a08666645f6 · keras-team/keras · GitHub