目录
multiprocess.Queue是多进程安全的队列,可以使用Queue实现多进程之间的数据传递, 底层使用管道pipe,同步信号量和互斥锁实现。
multiprocess库提供了pool进程池,直接实现了多进程之间的通信,pool的使用场景也很多,这里不做介绍。但pool不能实现父子进程之间的通信,要想实现父子通信,需要自己用Queue写一个进程池,通过创建子队列和父队列来进行父子通信。这个用途就更广泛了,比如想要开多个子进程帮你处理一些数据或文件,最后把结果都收回到主进程,这也就是自己编写DataLoader,很多实际项目中需要自己写一个数据预处理代码,所以要自己重写DataLoader。
本文先讲如何实现父子进程通信,之后就很容易实现DataLoader了。
1. 用Queue写进程池从而实现父子进程通信
思路其实很简单,先分别创建子队列父队列,然后在主进程中把文件放入子队列,多个子进程各自从子队列中取文件并进行处理,再把各自处理好的数据放入父队列,主进程等所有文件处理完后从父队列取处理好的数据。
这里比较需要注意的是用None作为退出信号,因为如果是用Queue的get()函数(默认block=False)则会在队列为空时阻塞等待而不退出;如果用get_nowait()也即get(block=True),则会在队列为空时报出异常;设置成当队列为空则退出也是不合理的。因此可以考虑用None作为退出信号,由主进程在文件都放入子队列后发出。
这里省略了具体处理data的函数。由于比较容易理解所以直接上代码。
from multiprocessing import Process, Queue
def data_preprocess(file, args):
# preprocess data
pass
def preprocess_module(queue_in, queue_res, args):
while True:
file = queue.get()
# exit if getting terminating signal
if file is None:
break
instance = data_preprocess(file, args)
queue_res.put(instance)
# put None into queue_res as terminating signal
queue_res.put(None)
if __name__ == '__main__':
nthread = 2
data_path = ' '
files = os.listdir(data_path)
queue_in = Queue(nthread) # child queue, maxlength of Queue equals to nthread.
queue_res = Queue() # parant queue
# create multiprocess
processes = [Process(target=perprocess_module, args=(
queue_in, queue_res, args)) for _ in range(nthread)]
# start multiprocess
for each in processes:
# terminate all child processes when parant process normally exits
each.daemon = True
each.start()
# feed files to multiprocess
for file in files:
queue_in.put(file)
# put None into queue_in as terminating signal
for i in range(args.core_num):
queue_in.put(None)
# parant process fetches results
cnt_None = 0
while True:
if cnt_None == args.core_num:
break
t = queue_res.get()
if t is None:
cnt_None += 1
else:
res_list.append(t)
# join multiprocess
try:
for each in processes:
each.join()
except Exception as e:
print(str(e))
2. 自己编写DataLoader
重写torch的DataLoader只需要在self.init中执行上述的多进程处理流程,在def __getitem__(self):中返回处理好的结果即可。
class Dataset(torch.utils.data.Dataset):
def __init__(self, args, batch_size):
data_dir = args.data_dir
nthread = args.core_num
self.res_list = []
files = os.listdir(data_dir)
queue_in = Queue(nthread) # child queue, maxlength of Queue equals to nthread.
queue_res = Queue() # parant queue
# create child processes
processes = [Process(target=perprocess_module, args=(
queue_in, queue_res, args)) for _ in range(nthread)]
# start child processes
for each in processes:
# terminate all child processes when parant process normally exits
each.daemon = True
each.start()
# feed files to child processes
for file in files:
queue_in.put(file)
# put None into queue_in as terminating signal
for i in range(args.core_num):
queue_in.put(None)
# parant process fetches results
cnt_None = 0
while True:
if cnt_None == args.core_num:
break
t = queue_res.get()
if t is None:
cnt_None += 1
else:
res_list.append(t)
# join child processes
try:
for each in processes:
each.join()
except Exception as e:
print(str(e))
def __len__(self):
return len(self.res_list)
def __getitem__(self, idx):
instance = self.res_list[idx]
return instance
要特别注意的是,Process库所执行的这个target_function,也即这行代码中的preprocess_module,不能定义在这行代码所在的域内,就是说要么放在整个文件的最外层即if __name__ == '__main__':之外,或者单独写成类的一个方法self.preprocess_module,否则会报错AttributeError:Can't pickle local object 'get_dataset.<locals>.Dataset
processes = [Process(target=perprocess_module, args=(
queue_in, queue_res, args)) for _ in range(nthread)]
3. 关于多进程杀掉主进程会有子进程残留的问题
当用multiprocess.Pool和torch.multiprocessing.spawn实现多进程时用ctrl+c杀掉主进程是无法同时杀掉子进程的,这些子进程会成为僵尸进程,也即没有父进程管理它们,只能用查看进程的方式一个个杀掉僵尸进程。注意:daemon=True在官方文档中说的是开启子进程守护只是在父进程"正常退出"的情况下会回收所有子进程,如果中途用ctrl+c终止进程是属于异常退出,所以还是有僵尸进程。
for each in processes:
# terminate all child processes when parant process normally exits
each.daemon = True
神奇的是用Process自己编写进程池是可以在ctrl+c的情况下回收子进程的,具体原理我也不清楚,只是在实际使用中是这样。不过有时还是要用到Pool和spawn,所以这里放上我试过比较好用的解决方法。这部分的原因和解决原理我也没有深入研究,这里也是方便自己记录。
def term(sig_num, addtion):
print('current pid is %s, group id is %s' % (os.getpid(), os.getpgrp()))
os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
if __name__ == '__main__':
signal.signal(signal.SIGTERM, term)
for p in processes:
p.start()
try:
for p in processes:
p.join()
except Exception as e:
print(str(e))