用multiprocessing.Queue自己编写进程池实现DataLoader(父子进程通信)

目录

1. 用Queue写进程池从而实现父子进程通信

2. 自己编写DataLoader

3. 关于多进程杀掉主进程会有子进程残留的问题


multiprocess.Queue是多进程安全的队列,可以使用Queue实现多进程之间的数据传递, 底层使用管道pipe,同步信号量和互斥锁实现。

multiprocess库提供了pool进程池,直接实现了多进程之间的通信,pool的使用场景也很多,这里不做介绍。但pool不能实现父子进程之间的通信,要想实现父子通信,需要自己用Queue写一个进程池,通过创建子队列和父队列来进行父子通信。这个用途就更广泛了,比如想要开多个子进程帮你处理一些数据或文件,最后把结果都收回到主进程,这也就是自己编写DataLoader,很多实际项目中需要自己写一个数据预处理代码,所以要自己重写DataLoader。

本文先讲如何实现父子进程通信,之后就很容易实现DataLoader了。


1. 用Queue写进程池从而实现父子进程通信

思路其实很简单,先分别创建子队列父队列,然后在主进程中把文件放入子队列,多个子进程各自从子队列中取文件并进行处理,再把各自处理好的数据放入父队列,主进程等所有文件处理完后从父队列取处理好的数据。

这里比较需要注意的是用None作为退出信号,因为如果是用Queue的get()函数(默认block=False)则会在队列为空时阻塞等待而不退出;如果用get_nowait()也即get(block=True),则会在队列为空时报出异常;设置成当队列为空则退出也是不合理的。因此可以考虑用None作为退出信号,由主进程在文件都放入子队列后发出。

这里省略了具体处理data的函数。由于比较容易理解所以直接上代码。

from multiprocessing import Process, Queue


def data_preprocess(file, args):
    # preprocess data
    pass

def preprocess_module(queue_in, queue_res, args):
    while True:
        file = queue.get()
        # exit if getting terminating signal
        if file is None:
            break
        instance = data_preprocess(file, args)
        queue_res.put(instance)
    # put None into queue_res as terminating signal
    queue_res.put(None)


if __name__ == '__main__':
    nthread = 2
    data_path = ' '
    files = os.listdir(data_path)

    queue_in = Queue(nthread)  # child queue, maxlength of Queue equals to nthread.
    queue_res = Queue()  # parant queue

    # create multiprocess
    processes = [Process(target=perprocess_module, args=(
                 queue_in, queue_res, args)) for _ in range(nthread)]
    
    # start multiprocess
    for each in processes:
        # terminate all child processes when parant process normally exits
        each.daemon = True
        each.start()

    # feed files to multiprocess 
    for file in files:
        queue_in.put(file)
    
    # put None into queue_in as terminating signal
    for i in range(args.core_num):
        queue_in.put(None)

    # parant process fetches results
    cnt_None = 0
    while True:
        if cnt_None == args.core_num:
            break
        t = queue_res.get()
        if t is None:
            cnt_None += 1
        else:
            res_list.append(t)

    # join multiprocess
    try:
        for each in processes:
            each.join()
    except Exception as e:
        print(str(e))

2. 自己编写DataLoader

重写torch的DataLoader只需要在self.init中执行上述的多进程处理流程,在def __getitem__(self):中返回处理好的结果即可。

class Dataset(torch.utils.data.Dataset):
    def __init__(self, args, batch_size):
        data_dir = args.data_dir
        nthread = args.core_num
        self.res_list = []
        

        files = os.listdir(data_dir)
        queue_in = Queue(nthread)  # child queue, maxlength of Queue equals to nthread.
        queue_res = Queue()  # parant queue

        # create child processes
        processes = [Process(target=perprocess_module, args=(
                     queue_in, queue_res, args)) for _ in range(nthread)]
    
        # start child processes
        for each in processes:
            # terminate all child processes when parant process normally exits
            each.daemon = True
            each.start()

        # feed files to child processes
        for file in files:
            queue_in.put(file)

        # put None into queue_in as terminating signal
        for i in range(args.core_num):
            queue_in.put(None)
    
        # parant process fetches results
        cnt_None = 0
        while True:
            if cnt_None == args.core_num:
                break
            t = queue_res.get()
            if t is None:
                cnt_None += 1
            else:
                res_list.append(t)

        # join child processes
        try:
            for each in processes:
                each.join()
        except Exception as e:
            print(str(e))

    def __len__(self):
        return len(self.res_list)

    def __getitem__(self, idx):
        instance = self.res_list[idx]
        return instance

要特别注意的是,Process库所执行的这个target_function,也即这行代码中的preprocess_module,不能定义在这行代码所在的域内,就是说要么放在整个文件的最外层即if __name__ == '__main__':之外,或者单独写成类的一个方法self.preprocess_module,否则会报错AttributeError:Can't pickle local object 'get_dataset.<locals>.Dataset

processes = [Process(target=perprocess_module, args=(
            queue_in, queue_res, args)) for _ in range(nthread)]

3. 关于多进程杀掉主进程会有子进程残留的问题

当用multiprocess.Pool和torch.multiprocessing.spawn实现多进程时用ctrl+c杀掉主进程是无法同时杀掉子进程的,这些子进程会成为僵尸进程,也即没有父进程管理它们,只能用查看进程的方式一个个杀掉僵尸进程。注意:daemon=True在官方文档中说的是开启子进程守护只是在父进程"正常退出"的情况下会回收所有子进程,如果中途用ctrl+c终止进程是属于异常退出,所以还是有僵尸进程。

for each in processes:
    # terminate all child processes when parant process normally exits
    each.daemon = True

神奇的是用Process自己编写进程池是可以在ctrl+c的情况下回收子进程的,具体原理我也不清楚,只是在实际使用中是这样。不过有时还是要用到Pool和spawn,所以这里放上我试过比较好用的解决方法。这部分的原因和解决原理我也没有深入研究,这里也是方便自己记录。

def term(sig_num, addtion):
    print('current pid is %s, group id is %s' % (os.getpid(), os.getpgrp()))
    os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)

if __name__ == '__main__':
    signal.signal(signal.SIGTERM, term)

    for p in processes:
        p.start()
    try:
        for p in processes:
            p.join()
    except Exception as e:
        print(str(e))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值