在看过这篇文章的人中,似乎没有人给出最终答案,因此我想给出对我有用的答案.由于该领域缺乏文档,我的答案可能缺少一些相关细节.请随时添加我在这里没有提及的更多信息.
似乎在Windows中不支持用Python编写一个继承Sequence类的生成器类. (您似乎可以使其在Linux上运行.)要使其运行,您需要设置参数use_multiprocessing = True(使用类方法).但是如上所述,它在Windows上不起作用,因此您必须将use_multiprocessing设置为False(在Windows上).
但是,这并不意味着多重处理在Windows上不起作用.即使将use_multiprocessing = False设置为false,使用以下设置运行代码时仍可以支持多处理,只需将worker参数设置为任何大于1的值即可.
例:
history = \n merged_model.fit_generator(generator=train_generator,
steps_per_epoch=trainset_steps_per_epoch,
epochs=300,
verbose=1,
use_multiprocessing=False,
workers=3,
max_queue_size=4)
此时,让我们再次记住Keras文档:
The use of keras.utils.Sequence guarantees the ordering and guarantees
the single use of every input per epoch when using
use_multiprocessing=True.
据我了解,如果use_multiprocessing = False,则生成器不再是线程安全的,这使得编写继承Sequence的生成器类变得困难.
为了解决这个问题,我自己编写了一个生成器,该生成器手动使线程安全.这是一个伪代码示例:
import tensorflow as tf
import threading
class threadsafe_iter:
"""Takes an iterator/generator and makes it thread-safe by
serializing call to the `next` method of given iterator/generator.
"""
def __init__(self, it):
self.it = it
self.lock = threading.Lock()
def __iter__(self):
return self
def __next__(self): # Py3
return next(self.it)
#def next(self): # Python2 only
# with self.lock:
# return self.it.next()
def threadsafe_generator(f):
"""A decorator that takes a generator function and makes it thread-safe.
"""
def g(*a, **kw):
return threadsafe_iter(f(*a, **kw))
return g
@threadsafe_generator
def generate_data(tfrecord_file_path_list, ...):
dataset = tf.data.TFRecordDataset(tfrecord_file_path_list)
# example proto decode
def _parse_function(example_proto):
...
return batch_data
# Parse the record into tensors.
dataset = dataset.map(_parse_function)
dataset = dataset.shuffle(buffer_size=100000)
# Repeat the input indefinitly
dataset = dataset.repeat()
# Generate batches
dataset = dataset.batch(batch_size)
# Create an initializable iterator
iterator = dataset.make_initializable_iterator()
# Get batch data
batch_data = iterator.get_next()
iterator_init_op = iterator.make_initializer(dataset)
with tf.Session() as sess:
sess.run(iterator_init_op)
while True:
try:
batch_data = sess.run(batch_data)
except tf.errors.OutOfRangeError:
break
yield batch_data
好吧,可以这样进行讨论是否真的很优雅,但似乎运行得很好.
总结一下:
>如果在Windows上编写程序,请将use_multiprocessing设置为False.
>(据我所知,到今天为止)在Windows上编写代码时,不支持编写一个继承Sequence的生成器类. (我猜这是一个Tensorflow / Keras问题).
>要解决此问题,请编写一个普通的生成器,使生成器线程安全,并将worker设置为大于1的数字.
重要说明:在此设置中,生成器在CPU上运行,而训练在GPU上完成.我可以观察到的一个问题是,如果您正在训练的模型足够浅,则GPU的利用率仍然很低,而CPU利用率却很高.如果模型较浅并且数据集足够小,那么将所有数据存储在内存中并在GPU上运行所有数据是一个不错的选择.它应该大大加快培训速度.如果出于任何原因想要同时使用CPU和GPU,我的建议是尝试使用Tensorflow的tf.data API,该API可显着加快数据预处理和批处理的速度.如果生成器仅使用Python编写,则GPU会一直等待数据以继续训练.可以说有关Tensorflow / Keras文档的所有内容,但这确实是高效的代码!
如果您对API有更全面的了解,并且看到这篇文章,请随时在这里纠正我,以防万一我误解了任何东西,或者更新了API以解决问题,甚至在Windows上也是如此.