Multiprocessing and multithreading in Python: speeding up data preparation for Keras training

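Copyright notice: this is an original article by the blogger, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.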





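1. At test time, calling the predict function causes the model to be recompiled and loaded, which adds overhead. To measure the time the forward pass really takes at test time, compute the output through an intermediate-layer function instead.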
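The timing idea in item 1 can be sketched as follows. This is a minimal sketch, not the author's code: `time_forward` and `dummy_forward` are hypothetical names, and `dummy_forward` stands in for a real forward function (in Keras, one assumption would be a backend function built from the model input and a layer output). The warm-up calls are there so that one-off setup cost is not counted.

```python
import time


def time_forward(forward_fn, batch, warmup=2, runs=10):
    """Return the mean wall-clock seconds per call of forward_fn(batch)."""
    # Warm-up calls absorb one-off costs (graph building, memory allocation).
    for _ in range(warmup):
        forward_fn(batch)
    start = time.perf_counter()
    for _ in range(runs):
        forward_fn(batch)
    return (time.perf_counter() - start) / runs


def dummy_forward(batch):
    """Hypothetical stand-in for a real forward pass."""
    return [2 * x for x in batch]


mean_seconds = time_forward(dummy_forward, list(range(1000)))
print("mean forward time: %.6f s" % mean_seconds)
```

Averaging over several runs after a warm-up gives a steadier figure than timing a single call.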

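2. During training, the speedup gained from multiple GPUs depends on how the time to compute one mini-batch compares with the time to synchronize the weights.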


It seems that not all models benefit from multi_gpu_model.
Different models scale differently because of the overhead of weight synchronization.

ResnetV1 and ResnetV2 are a typical pair of examples: ResnetV2 scales better than ResnetV1.

There is a balance between training one mini-batch and synchronizing the weights. InceptionV3 has a heavy computational cost per mini-batch but relatively few weights to synchronize, so it gains a decent boost from multi_gpu_model.

However, a model with a large Dense layer usually scales badly. mnist_mlp is an example: training one mini-batch is computationally cheap, but its weights are too large to synchronize efficiently. In the time it takes to do one weight synchronization, a single GPU could finish training many mini-batches, so mnist_mlp does not benefit from multi_gpu_model.

It also indicates that whether a model benefits from multi_gpu_model can differ across GPU architectures, since it largely depends on how the time the GPU needs to train one mini-batch compares with the time for one weight synchronization. So another conclusion is that the faster the GPU, the less likely multi_gpu_model is to speed up your model.
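The balance described above can be illustrated with a toy estimate. This is only a sketch under a strong assumption: one GPU processes n mini-batches one after another, while n GPUs process them at once and then pay one synchronization. The function name and the timings are hypothetical, not measurements.

```python
def multi_gpu_speedup(n_gpus, compute_time, sync_time):
    """Toy estimate of the speedup over one GPU for n_gpus mini-batches.

    One GPU needs n_gpus * compute_time for n_gpus mini-batches.
    n_gpus GPUs compute them at the same time, then pay sync_time once.
    """
    single_gpu_time = n_gpus * compute_time
    multi_gpu_time = compute_time + sync_time
    return single_gpu_time / multi_gpu_time


# Heavy compute, cheap synchronization (the InceptionV3-like case).
heavy_compute = multi_gpu_speedup(4, compute_time=1.0, sync_time=0.1)
# Cheap compute, expensive synchronization (the mnist_mlp-like case).
heavy_sync = multi_gpu_speedup(4, compute_time=0.01, sync_time=0.5)
# Same model on a GPU that is twice as fast: compute shrinks, sync does not.
faster_gpu = multi_gpu_speedup(4, compute_time=0.5, sync_time=0.1)

print(heavy_compute, heavy_sync, faster_gpu)
```

In this toy model a value below 1 means the multi-GPU setup is slower than a single GPU, which matches the mnist_mlp observation.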


3. Channel ordering affects Keras training speed. The data format, NCHW (channels_first) versus NHWC (channels_last), affects Keras performance.
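To make the two layouts concrete, here is a small pure-Python sketch of how the same element is addressed in a flat buffer under each ordering. The helper names are hypothetical, and real frameworks do this with optimized tensor transposes rather than Python loops.

```python
def nchw_offset(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in channels-first layout."""
    return ((n * C + c) * H + h) * W + w


def nhwc_offset(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in channels-last layout."""
    return ((n * H + h) * W + w) * C + c


def nhwc_to_nchw(flat, N, C, H, W):
    """Reorder a flat NHWC buffer into a flat NCHW buffer."""
    out = [None] * len(flat)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    src = nhwc_offset(n, c, h, w, C, H, W)
                    dst = nchw_offset(n, c, h, w, C, H, W)
                    out[dst] = flat[src]
    return out


# One 2x2 image with 3 channels; in NHWC each pixel stores r, g, b together.
nhwc = ["r0", "g0", "b0", "r1", "g1", "b1",
        "r2", "g2", "b2", "r3", "g3", "b3"]
nchw = nhwc_to_nchw(nhwc, N=1, C=3, H=2, W=2)
print(nchw)
```

In NCHW all values of one channel are contiguous, while in NHWC all channels of one pixel are contiguous, which is why the same convolution can run at different speeds depending on the layout and the backend.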

4. Processes and threads in fit_generator().

The pickle_safe argument mentioned below was renamed in newer versions of Keras: 'pickle_safe' ---> 'use_multiprocessing'. Likewise, nb_worker was renamed to workers.


Keras Tips & Tricks


  • Keras can use either threading or multiprocessing for concurrent and parallel processing, respectively, of the data generator.
  • In the threading approach (model.fit_generator(..., pickle_safe=False)), the generator can be run concurrently (but not parallel) in multiple threads, with each thread pulling the next available batch based on the shared state of the generator and placing it in a shared queue. However, the generator must be threadsafe (i.e. use locks at synchronization points).
  • Due to the Python global interpreter lock (GIL), the threading option generally does not benefit from more than one worker (i.e. model.fit_generator(..., nb_worker=1) is best). One possible use case in which more than one thread could be beneficial is the presence of exceptionally long IO times, during which the GIL is released to enable concurrency. Note also that TensorFlow releases the GIL while it executes a training step, so a generator thread can actually run in parallel with a training iteration. To achieve the best performance with this approach, the wall-clock time for generating a single batch with the data generator should be less than that of a single training iteration. Otherwise, data generation becomes the bottleneck.
  • In the multiprocessing approach (model.fit_generator(..., pickle_safe=True)), the generator can be copied and run in parallel by multiple processes. The key thing to realize here is that the multiprocessing approach will (by default) fork the current process for each worker process, so each process effectively starts with a copy of the generator. Each process then runs its own "copy" of the generator in parallel, with no synchronization. Thus, while any generator will run without error, if one is not careful, this approach can result in the processes generating the exact same batches at (basically) the same time (i.e. a deterministic generator will be evaluated in the same manner in each process). The issue here is that with n processes, the model may see the same batch for n consecutive steps, and an "epoch" may actually consist of total_batches/n unique batches, rather than total_batches. To fix this, the generator can be reformulated so that it relies on NumPy random numbers for generating batches, as the GeneratorEnqueuer class will set a random seed with np.random.seed() for each process.
  • Due to the overhead of (de)serialization, the multiprocessing option generally only benefits from >1 worker (i.e. model.fit_generator(..., nb_worker=8)), and will generally result in much better performance than the threading option.
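The thread-safety requirement for the threading approach can be sketched with a lock around `next()`. This is a common pattern, not code taken from Keras; `ThreadSafeIterator` and `batch_indices` are hypothetical names, and the generator yields index ranges instead of real data.

```python
import threading


class ThreadSafeIterator:
    """Wrap an iterator so that only one thread at a time advances it."""

    def __init__(self, iterator):
        self._iterator = iterator
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # The lock is the synchronization point shared by all worker threads.
        with self._lock:
            return next(self._iterator)


def batch_indices(num_samples, batch_size):
    """Endlessly yield (start, stop) index pairs, one pair per batch."""
    while True:
        for start in range(0, num_samples, batch_size):
            yield start, min(start + batch_size, num_samples)


safe_gen = ThreadSafeIterator(batch_indices(100, 10))
results = []
results_lock = threading.Lock()


def worker(num_batches):
    """Pull batches from the shared generator, as a worker thread would."""
    for _ in range(num_batches):
        batch = next(safe_gen)
        with results_lock:
            results.append(batch)


threads = [threading.Thread(target=worker, args=(25,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(results), "batches pulled by 4 threads")
```

Without the lock, two threads calling `next()` on a plain generator at the same moment can raise "ValueError: generator already executing".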


  • In light of the above information, the ImageDataGenerator is "threadsafe", so it can be used with the "threading" approach above. However, it is not completely appropriate for the "multiprocessing" approach due to the issue described above, despite not throwing any errors. If an ImageDataGenerator generator is used with the multiprocessing approach above, the first epoch will suffer from the problem of the same deterministic generator in each process, so the same batches will be produced at (basically) the same time by each process. If shuffling is used, the generators will diverge in subsequent epochs, because the Iterator superclass randomly permutes the indices with index_array = np.random.permutation(n) at the start of each epoch, making use of the random seed set for each process.
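The duplicated-batch problem of the multiprocessing approach can be simulated without starting real processes. In this sketch, copying an iterator object stands in for what a fork does, and the standard-library `random.Random` with a per-worker seed stands in for the per-process NumPy seeding mentioned above. All class and function names are hypothetical.

```python
import copy
import random


class SequentialBatches:
    """Deterministic iterator: yields batch ids 0, 1, 2, ... in order."""

    def __init__(self, total_batches):
        self.total_batches = total_batches
        self.position = 0

    def __iter__(self):
        return self

    def __next__(self):
        batch_id = self.position % self.total_batches
        self.position += 1
        return batch_id


class RandomBatches:
    """Iterator that draws each batch id from its own seeded RNG."""

    def __init__(self, total_batches, seed):
        self.total_batches = total_batches
        self.rng = random.Random(seed)

    def __iter__(self):
        return self

    def __next__(self):
        return self.rng.randrange(self.total_batches)


def simulate_workers(make_worker_iter, n_workers, steps):
    """Take batches from n_workers iterators in round-robin order."""
    workers = [make_worker_iter(i) for i in range(n_workers)]
    return [next(workers[step % n_workers]) for step in range(steps)]


parent = SequentialBatches(total_batches=8)
# Each "forked" worker starts from an identical copy of the parent state.
forked = simulate_workers(lambda i: copy.deepcopy(parent), 4, 8)
# Each worker gets its own seed, so the streams are independent.
seeded = simulate_workers(lambda i: RandomBatches(8, 1000 + i), 4, 8)

print(forked)  # [0, 0, 0, 0, 1, 1, 1, 1]
print(seeded)
```

With 4 simulated workers and 8 steps, the copied deterministic iterators deliver only 8/4 = 2 unique batches, which is the total_batches/n effect described in the list above.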

5. Multiprocessing and multithreading in Python.

Options include the Parallel Python (pp) module and the multiprocessing module.

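Both the pp (Parallel Python) module and the multiprocessing module can run computations in parallel. Their implementations appear to differ, and pp seems to be the simpler of the two.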
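For the multiprocessing module, a minimal sketch of parallel data preparation might look like the following. `prepare_sample` is a hypothetical CPU-bound stand-in for real preprocessing, and the worker count is an arbitrary choice. The `__main__` guard is needed on platforms that start worker processes by re-importing the script.

```python
import multiprocessing


def prepare_sample(index):
    """Hypothetical CPU-bound preprocessing step for one sample."""
    return sum(i * i for i in range(index))


def prepare_all(indices, workers=2):
    """Map prepare_sample over indices using a pool of worker processes."""
    with multiprocessing.Pool(processes=workers) as pool:
        # pool.map returns the results in the same order as the inputs.
        return pool.map(prepare_sample, indices)


if __name__ == "__main__":
    print(prepare_all([10, 100, 1000]))
```

Because each worker is a separate process with its own interpreter, CPU-bound work like this is not limited by the GIL, at the cost of pickling the inputs and results.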