结合序列,使用multi_processing = False和workers =例如. 4确实有效.
我刚刚意识到,在问题的示例代码中,我没有看到加速,因为数据生成太快.通过插入time.sleep(2),这变得很明显.
class DummySequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
time.sleep(2)
return np.array(batch_x), np.array(batch_y)
x = np.random.random((100, 3))
y = to_categorical(np.random.random(100) > .5).astype(int)
seq = DummySequence(x, y, 10)
model = Sequential()
model.add(Dense(32, input_dim=3))
model.add(Dense(2, activat