python keras_python-在Keras / Tensorflow中类生成器(继承序列...

在看过这篇文章的人中,似乎没有人给出最终答案,因此我想给出对我有用的答案.由于该领域缺乏文档,我的答案可能缺少一些相关细节.请随时添加我在这里没有提及的更多信息.

似乎在Windows中不支持用Python编写一个继承Sequence类的生成器类. (您似乎可以使其在Linux上运行.)要使其运行,您需要设置参数use_multiprocessing = True(使用类方法).但是如上所述,它在Windows上不起作用,因此您必须将use_multiprocessing设置为False(在Windows上).

但是,这并不意味着多重处理在Windows上不起作用.即使将use_multiprocessing = False设置为false,使用以下设置运行代码时仍可以支持多处理,只需将worker参数设置为任何大于1的值即可.

例:

history = \n merged_model.fit_generator(generator=train_generator,

steps_per_epoch=trainset_steps_per_epoch,

epochs=300,

verbose=1,

use_multiprocessing=False,

workers=3,

max_queue_size=4)

此时,让我们再次记住Keras文档:

The use of keras.utils.Sequence guarantees the ordering and guarantees

the single use of every input per epoch when using

use_multiprocessing=True.

据我了解,如果use_multiprocessing = False,则生成器不再是线程安全的,这使得编写继承Sequence的生成器类变得困难.

为了解决这个问题,我自己编写了一个生成器,该生成器手动使线程安全.这是一个伪代码示例:

import tensorflow as tf

import threading

class threadsafe_iter:

"""Takes an iterator/generator and makes it thread-safe by

serializing call to the `next` method of given iterator/generator.

"""

def __init__(self, it):

self.it = it

self.lock = threading.Lock()

def __iter__(self):

return self

def __next__(self): # Py3

return next(self.it)

#def next(self): # Python2 only

# with self.lock:

# return self.it.next()

def threadsafe_generator(f):

"""A decorator that takes a generator function and makes it thread-safe.

"""

def g(*a, **kw):

return threadsafe_iter(f(*a, **kw))

return g

@threadsafe_generator

def generate_data(tfrecord_file_path_list, ...):

dataset = tf.data.TFRecordDataset(tfrecord_file_path_list)

# example proto decode

def _parse_function(example_proto):

...

return batch_data

# Parse the record into tensors.

dataset = dataset.map(_parse_function)

dataset = dataset.shuffle(buffer_size=100000)

# Repeat the input indefinitly

dataset = dataset.repeat()

# Generate batches

dataset = dataset.batch(batch_size)

# Create an initializable iterator

iterator = dataset.make_initializable_iterator()

# Get batch data

batch_data = iterator.get_next()

iterator_init_op = iterator.make_initializer(dataset)

with tf.Session() as sess:

sess.run(iterator_init_op)

while True:

try:

batch_data = sess.run(batch_data)

except tf.errors.OutOfRangeError:

break

yield batch_data

好吧,可以这样进行讨论是否真的很优雅,但似乎运行得很好.

总结一下:

>如果在Windows上编写程序,请将use_multiprocessing设置为False.

>(据我所知,到今天为止)在Windows上编写代码时,不支持编写一个继承Sequence的生成器类. (我猜这是一个Tensorflow / Keras问题).

>要解决此问题,请编写一个普通的生成器,使生成器线程安全,并将worker设置为大于1的数字.

重要说明:在此设置中,生成器在CPU上运行,而训练在GPU上完成.我可以观察到的一个问题是,如果您正在训练的模型足够浅,则GPU的利用率仍然很低,而CPU利用率却很高.如果模型较浅并且数据集足够小,那么将所有数据存储在内存中并在GPU上运行所有数据是一个不错的选择.它应该大大加快培训速度.如果出于任何原因想要同时使用CPU和GPU,我的建议是尝试使用Tensorflow的tf.data API,该API可显着加快数据预处理和批处理的速度.如果生成器仅使用Python编写,则GPU会一直等待数据以继续训练.可以说有关Tensorflow / Keras文档的所有内容,但这确实是高效的代码!

如果您对API有更全面的了解,并且看到这篇文章,请随时在这里纠正我,以防万一我误解了任何东西,或者更新了API以解决问题,甚至在Windows上也是如此.

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值