A Python3 multiprocessing JoinableQueue template

Recently I needed to process a batch of files on a server. Processing each file is simple: read the file, compute some statistics, then aggregate the statistics. My first reaction: this can be parallelized! After some fiddling with Python3's multiprocessing, I am leaving a template here for future use.

The program runs as follows (a condensed, runnable sketch follows the list):

  • The main process scans for the files to be processed and builds a file list.
  • The main process creates the job queue, the result queue, and the log queue. All queues start empty.
  • The main process creates the log process, which is responsible for printing log messages.
  • The main process creates all the worker subprocesses. They start and listen for messages from the job queue (a blocking get()).
  • The main process creates the result-queue process, which handles the results sent back by the workers.
  • The main process puts the entries of the file list, one by one, into the job queue (a JoinableQueue). The workers get() jobs from the job queue, process the files, and put() the results into the result queue.
  • The result process get()s data from the result queue and stores it temporarily.
  • Once all jobs have been submitted, the main process joins the job queue.
  • After handling all the results, the result process exits on its own.
  • The main process sends the exit flag to all worker subprocesses.
  • The main process joins the worker processes.
  • The main process joins the result process.
  • The main process joins the log queue.
  • The main process sends the exit flag to the log process.
  • The main process joins the log process.
  • The main process exits.
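
Before the full template, here is a condensed, runnable sketch of the same flow. This is my simplification, not a drop-in replacement: it uses sentinel values instead of pipes for shutdown, and folds the result and log processes into the main process.

import multiprocessing

def worker(jobQ, resQ):
    # Keep taking jobs until the sentinel (None) arrives.
    while True:
        job = jobQ.get()
        if job is None:
            jobQ.task_done()
            break
        resQ.put(job * 2)  # stand-in for the real per-file work
        jobQ.task_done()

if __name__ == '__main__':
    nJobs, nProc = 10, 2
    jobQ = multiprocessing.JoinableQueue()
    resQ = multiprocessing.Queue()

    workers = [ multiprocessing.Process(target=worker, args=(jobQ, resQ))
                for _ in range(nProc) ]
    for p in workers:
        p.start()

    for i in range(nJobs):
        jobQ.put(i)

    # Drain the result queue BEFORE stopping the workers (see the pitfall below).
    results = [ resQ.get() for _ in range(nJobs) ]

    jobQ.join()              # every job has been acknowledged with task_done()
    for _ in workers:        # one sentinel per worker
        jobQ.put(None)
    jobQ.join()
    for p in workers:
        p.join()

    print(results)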

The debugging went fairly smoothly; I had briefly used Python's built-in multiprocessing tools before. I did run into one problem: the main process sent the exit flag to the subprocesses first, and only then started fetching data from the result queue, which left the main process blocked forever on get(). The mechanism is roughly this (a minimal reproduction follows the list):

  • When a subprocess put()s data into the result queue and the queue already holds too much un-fetched data, the data is buffered by a pipe (a feeder thread) and only moved into the queue once there is room.
  • If the main process does not get() from the result queue in time, some buffered data never makes it into the queue. If the main process then sends the exit flag, the subprocess exits before its put() has completed, and the data is lost.
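
A minimal reproduction of this feeder-thread behavior, adapted from the pitfall described in the multiprocessing documentation ("joining processes that use queues"); it is not my template's exact code:

import multiprocessing

def child(q):
    # A payload large enough to overflow the pipe buffer, so the queue's
    # feeder thread still holds data when the child tries to terminate.
    q.put('x' * (1 << 22))

if __name__ == '__main__':
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=child, args=(q,))
    p.start()
    p.join()        # DEADLOCK: join() waits for the child's feeder thread
                    # to flush the queue, which waits for a get() that
                    # never comes.
    print(len(q.get()))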

Lesson learned: the process that drains a queue with get() must stay alive until everything that needs to be put() has been put and the queue is empty; only then is it safe to terminate the processes that call put().
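
Applied to the toy example above (reusing its child()), the rule simply means drain first, join second:

if __name__ == '__main__':
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=child, args=(q,))
    p.start()
    data = q.get()  # drain first: the feeder thread can flush everything
    p.join()        # now safe: nothing is left buffered in the child
    print(len(data))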

One more note: Python3 names the Queue package differently from Python2 (queue vs. Queue), which matters when catching the queue.Empty exception.

The template source code follows.


# Author: Yaoyu Hu <yyhu_live@outlook.com>

import argparse
import multiprocessing
import time

import sys
if ( sys.version_info[0] == 2 ):
    import Queue as queue
else:
    import queue

class Printer(object):
    def __init__(self, name, q=None):
        super(Printer, self).__init__()
        self.name = name
        self.q = q

    def __call__(self, s, level=0):
        line = '%s: %s' % (self.name, s)
        if ( self.q is None ):
            print(line)
        else:
            self.q.put([level, line], block=True)

def worker_log(name, logQ, conn, level=0):
    '''
    Send a string "exit" through conn to stop this process.

    Arguments: 
    name (str): Name of this process.
    logQ (queue): Log queue.
    conn (pipe connection): The pipe connection for receiving commands.
    level (int): Log level.
    '''

    printer = Printer(name)
    
    flagExit = False
    while (True):
        if (conn.poll()):
            command = conn.recv()

            printer('\"%s\" command received.' % (command))

            if ('exit' == command):
                flagExit = True

        try:
            logList = logQ.get(False)

            if ( logList[0] >= level ):
                print( logList[1] )

            logQ.task_done()
        except queue.Empty as exp:
            if ( flagExit ):
                break
            # Avoid busy-waiting while the log queue is empty.
            time.sleep(0.05)

    printer('Exit.')

def worker_rq(name, resQ, logQ, nJobs, 
            printInterval=10, waitWhenEmpty=1, timeoutCountLimit=100):
    '''
    This worker process handles the data in the result queue. Each get() blocks
    for up to 1 s waiting for a result to come in. When no result arrives, a
    timeout counter is increased to record how many consecutive attempts have
    failed; the counter resets whenever a new result is processed. The limit on
    consecutive timeouts is set by the argument timeoutCountLimit.

    If timeoutCountLimit is reached, this process terminates itself.

    Arguments: 
    name (str): Name of this process.
    resQ (queue): Result queue.
    logQ (queue): Log queue.
    nJobs (int): Expected number of jobs.
    printInterval (int): Interval of print.
    waitWhenEmpty (float): Time to wait if the result queue is empty.
    timeoutCountLimit (int): Top limit for continuous timeouts.
    '''
    # The printer.
    printer = Printer(name, logQ)
    
    resultList   = []
    resultCount  = 0
    timeoutCount = 0

    flagOK = True
    while(resultCount < nJobs):
        try:
            r = resQ.get(block=True, timeout=1)
            resultList.append(r)
            resultCount += 1

            timeoutCount = 0

            if (resultCount % printInterval == 0):
                printer('worker_rq collected %d results.' % (resultCount))
        except queue.Empty as exp:
            printer('Wait on rq-index %d.' % (resultCount))
            timeoutCount += 1
            time.sleep(waitWhenEmpty)

            if ( timeoutCount == timeoutCountLimit ):
                printer('worker_rq reaches the timeout count limit (%d). Process abort.' % \
                    (timeoutCountLimit))
                flagOK = False
                break

    if (flagOK):
        printer('All results processed with no error.')
    else:
        printer('Not all results are received.')

    printer('resultCount = %d, nJobs = %d.' % (resultCount, nJobs))

def process_single_file(name, jobDict):
    '''
    Arguments:
    name (str): Name of the process.
    jobDict (dict): Job dictionary.
    '''

    startTime = time.time()
    idx = jobDict['idx']
    s = jobDict['s']
    print('%d, %s' % (idx, s))
    time.sleep(0.5)

    endTime = time.time()
    
    s = '%s: %ds for processing.' % (name, endTime - startTime)

    return s

def worker( name, jobQ, conn, resQ, logQ, 
        jobQBlockTime=1 ):
    '''
    Send a string "exit" through conn to stop this process.

    Arguments:
    name (str): Name of this worker process.
    jobQ (queue): Job queue.
    conn (pipe connection): Pipe connection for receiving commands.
    resQ (queue): Result queue.
    logQ (queue): Log queue.
    jobQBlockTime (float): Block time for getting a job from jobQ.
    '''

    # The printer.
    printer = Printer(name, q=logQ)
    printer('Worker starts.')

    while (True):
        if (conn.poll()):
            command = conn.recv()

            printer('\"%s\" command received.' % (command))

            if ('exit' == command):
                break

        try:
            jobDict = jobQ.get(True, jobQBlockTime)
            res = process_single_file(name, jobDict)

            resQ.put([res], block=True)

            jobQ.task_done()
        except queue.Empty as exp:
            pass
    
    printer('Exit.')

def handle_args():
    parser = argparse.ArgumentParser(description='Filter the files.')

    parser.add_argument('--jobs', type=int, default=10, \
        help='The number of jobs for testing.')

    parser.add_argument('--np', type=int, default=2, \
        help='The number of processes.')

    args = parser.parse_args()

    assert args.jobs > 0
    assert args.np > 0

    return args

if __name__ == '__main__':
    # ========== Handle arguments. ==========
    args = handle_args()

    # Create a Printer object.
    printer = Printer('Main')

    startTime = time.time()

    printer('Main process.')

    # ========== The queues. ==========
    logQ    = multiprocessing.JoinableQueue()
    jobQ    = multiprocessing.JoinableQueue()
    resultQ = multiprocessing.Queue()
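    # Note: resultQ is a plain Queue (not joinable): worker_rq drains it by
    # counting results against args.jobs, so it needs no task_done()/join()
    # bookkeeping.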

    # ========== Log queue worker. ==========
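    # Pipe(False) creates a one-way pipe: connLQ1 can only recv(), and
    # connLQ2 can only send().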
    [ connLQ1, connLQ2 ] = multiprocessing.Pipe(False)
    pWorkerLQ = multiprocessing.Process( 
        target=worker_log, 
        args=['Log', logQ, connLQ1, 0] )
    pWorkerLQ.start()

    # Update the printer.
    printer.q = logQ

    # ========== Processing workers. ==========
    processes = []
    pipes     = []
    printer('Create %d processes.' % (args.np))
    for i in range(args.np):
        [conn1, conn2] = multiprocessing.Pipe(False)
        processes.append( multiprocessing.Process( \
            target=worker, args=['P%03d' % (i), jobQ, conn1, resultQ, logQ]) )
        pipes.append(conn2)

    for p in processes:
        p.start()

    # ========== Result queue worker. ==========
    pWorkerRQ = multiprocessing.Process( 
        target=worker_rq, 
        args=['RQ', resultQ, logQ, args.jobs] )
    pWorkerRQ.start()

    printer('All processes started.')

    # ========== Submit jobs. ==========
    for dj in range(args.jobs):
        jobQ.put( { 'idx': dj, 's': 'Job string. ' } )

    printer('All jobs submitted.')

    # Wait for all the jobs to finish. join() returns once every job taken
    # from jobQ has been acknowledged with task_done().
    jobQ.join()
    printer('Job queue joined.')

    # ========== Stop all the processes. ==========
    # Stop all the processing workers. 
    for p in pipes:
        p.send('exit')

    printer('Exit command sent to all processes.')

    for p in processes:
        p.join()

    printer('All processes joined.')

    # Stop the result queue worker. 
    pWorkerRQ.join()
    printer('The result queue worker is joined.')

    # Stop the log queue worker. Join logQ first so that every pending log
    # line is printed before the exit flag is sent.
    logQ.join()
    connLQ2.send('exit')
    pWorkerLQ.join()

    print('Main: Log queue joined. Log queue worker stopped.')

    # ========== Final info. ==========
    endTime = time.time()

    print('Main: Job done. Total time is %ds.' % (endTime - startTime))
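
To try the template, save it as, say, mp_template.py (the file name is arbitrary) and run:

python3 mp_template.py --jobs 20 --np 4

This pushes 20 dummy jobs through 4 worker processes and prints the total time at the end.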