20220917 -
最近在做论文相关的实验,虽然深度学习部分由底层框架来实现,但是在后续的机器学习实验中,由于涉及到一些算法不能并行跑,所以导致速度比较慢。比如测试10个参数,那么跑十次,速度就下降很多。考虑到使用的机器CPU数量还比较多,那么就将其改变为多进程的执行方式。
这部分编程,我记得在几年前在进行日志处理的时候弄过这部分内容,当时也是采用多进程以及多线程的方式。当时是利用类定义的方式,定义队列等内容,逐步进行处理。这次不用这么麻烦,直接采用pool这个函数来实现,官方文档的代码如下:
from multiprocessing import Pool
def f(x):
return x*x
if __name__ == '__main__':
with Pool(5) as p:
print(p.map(f, [1, 2, 3]))
但是这里遇到了两个问题,一个是我为了图方便,在函数内容定义了一个子函数,以此来实现变量共享,但是其实我写的时候就觉得可能不太对,但是也就直接写了,这个时候报错信息是
“AttributeError: Can’t pickle local object in Multiprocessing”
看了这个信息,本质上也明白了,就是因为在内部定义了函数,所以他没办法将这部分函数传递到别的进程。那么只需要将这部分改到外面即可,但是也引发了另外一个问题,多个参数怎么办?由此搜索到另外一个解答[1],其实按照队列的方式,这就是一个很简单的问题,直接将这部分传递进去。或者,我觉得,直接将他包装成一个字典也都可以解决。但实际上,看文章[1]就可以看到,官方也有自己的解决方案。
#!/usr/bin/env python3
from functools import partial
from itertools import repeat
from multiprocessing import Pool, freeze_support
def func(a, b):
return a + b
def main():
a_args = [1,2,3]
second_arg = 1
with Pool() as pool:
L = pool.starmap(func, [(1, 1), (2, 1), (3, 1)])
M = pool.starmap(func, zip(a_args, repeat(second_arg)))
N = pool.map(partial(func, b=second_arg), a_args)
assert L == M == N
if __name__=="__main__":
freeze_support()
main()
以上是python3的版本,python2的话,有所不同
#!/usr/bin/env python2
import itertools
from multiprocessing import Pool, freeze_support
def func(a, b):
print a, b
def func_star(a_b):
"""Convert `f([1,2])` to `f(1,2)` call."""
return func(*a_b)
def main():
pool = Pool()
a_args = [1,2,3]
second_arg = 1
pool.map(func_star, itertools.izip(a_args, itertools.repeat(second_arg)))
if __name__=="__main__":
freeze_support()
main()
这样的话,就可以把这个问题解决了,目前通过这种方式,能直接缩短2/3的时间。
参考
[1]How to use multiprocessing pool.map with multiple arguments