Pytorch多进程睡死的大坑

  • Author:ZERO-A-ONE
  • Date:2021-03-09

最近在使用Pytorch编写一些多进程程序,遇到了一个大坑,就是Python常用的多进程库multiprocessing 在实现多进程的模式不同,对Pytorch程序的影响

一、起步

首先我写了如下的一段代码,使用了multiprocessing 的进程池的方法,想实现多进程的训练

import argparse
import time
import multiprocessing as mp
import CartPole
import Pendulum

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--frame', type=int, default=10000)
    parser.add_argument('--proc', type=int, default=4)
    args = parser.parse_known_args()[0]
    return args

if __name__ == '__main__':
    args = get_args()
    ave_step = args.frame / args.proc
    if args.task == 'CartPole-v0':
        main = CartPole.main
    else:
        main = Pendulum.main
    time_start = time.time()
    pool = mp.Pool(args.proc)
    for i in range(args.proc):
        pool.apply_async(main, (ave_step,))
    pool.close()
    pool.join()
    time_end = time.time()
    print("time cost:%f" % (time_end - time_start))

在Windows的本机上进行测试并没有什么问题,然后在实验室的Linux服务器上悲剧的出现了子进程无法退出,卡在了pool.join()处,导致整个程序无法正常继续往下执行的错误,于是我最后查阅了很多资料修改成了如下的形式

import argparse
import time
import multiprocessing as mp
import CartPole
import Pendulum

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--frame', type=int, default=10000)
    parser.add_argument('--proc', type=int, default=1)
    args = parser.parse_known_args()[0]
    return args
    
if __name__ == '__main__':
    mp.set_start_method("spawn")
    args = get_args()
    ave_step = args.frame / args.proc
    if args.task == 'CartPole-v0':
        main = CartPole.main
    else:
        main = Pendulum.main
    time_start = time.time()
    processes=[]
    for i in range(args.proc):
        p = mp.Process(target=main, args=(ave_step,))
        p.start()
        processes.append(p)
        
    for p in processes:
        p.join()

    time_end = time.time()
    print("time cost:%f" % (time_end - time_start))

大家可以仔细看看改动的地方,然后猜想一下造成这种情况的原因

二、线索

我们来回忆一下,Python有多少种创建子进程的方式

Python创建的子进程执行的内容,和启动该进程的方式有关。而根据不同的平台,启动进程的方式大致可分为以下 3 种:

  • spawn:使用此方式启动的进程,只会执行和target参数或者run()方法相关的代码,父进程启动一个新的Python解释器进程,子进程只继承运行run()方法所需的资源。来自父进程的不必要的文件描述符和句柄将不会被继承。Windows 平台只能使用此方法,事实上该平台默认使用的也是该启动方式。相比其他两种方式,此方式启动进程的效率最低
  • fork:使用此方式启动的进程,基本等同于主进程(即主进程拥有的资源,该子进程全都有)。因此,该子进程会从创建位置起,和主进程一样执行程序中的代码。注意,此启动方式仅适用于 UNIX 平台,os.fork() 创建的进程就是采用此方式启动的
  • forserver:使用此方式,程序将会启动一个服务器进程。即当程序每次请求启动新进程时,父进程都会连接到该服务器进程,请求由服务器进程来创建新进程。通过这种方式启动的进程不需要从父进程继承资源。注意,此启动方式只在UNIX平台上有效

总的来说,使用类 UNIX 平台,启动进程的方式有以上 3 种,而使用 Windows 平台,只能选用 spawn 方式(默认即可

Unix默认使用fork模式, windows 默认使用spawn

Python大致提供了以下两种手动设置进程启动方式的方法

2.1 set_start_method

multiprocessing模块提供了一个set_start_method() 函数,该函数可用于设置启动进程的方式。需要注意的是,该函数的调用位置,必须位于所有与多进程有关的代码之前

例如,下面代码演示了如何显式设置进程的启动方式:

import multiprocessing
import os
print("当前进程ID:",os.getpid())

# 定义一个函数,准备作为新进程的 target 参数
def action(name,*add):
    print(name)
    for arc in add:
        print("%s --当前进程%d" % (arc,os.getpid()))
if __name__=='__main__':
    #定义为进程方法传入的参数
    my_tuple = ("http://c.biancheng.net/python/",\
                "http://c.biancheng.net/shell/",\
                "http://c.biancheng.net/java/")
    #设置进程启动方式
    multiprocessing.set_start_method('spawn')
   
    #创建子进程,执行 action() 函数
    my_process = multiprocessing.Process(target = action, args = ("my_process进程",*my_tuple))
    #启动子进程
    my_process.start()

程序执行结果为:

当前进程ID: 24500
当前进程ID: 17300
my_process进程
http://c.biancheng.net/python/ --当前进程17300
http://c.biancheng.net/shell/ --当前进程17300
http://c.biancheng.net/java/ --当前进程17300

注意:由于此程序中进程的启动方式为 spawn,因此该程序可以在任意( Windows 和类 UNIX 上都可以 )平台上执行

2.2 get_context

除此之外,还可以使用 multiprocessing 模块提供的get_context()函数来设置进程启动的方法,调用该函数时可传入 “spawn”、“fork”、“forkserver” 作为参数,用来指定进程启动的方式

需要注意的一点是,前面在创建进程是,使用multiprocessing.Process()这种形式,而在使用 get_context()函数设置启动进程方式时,需用该函数的返回值,代替 multiprocessing 模块调用Process()

例如,下面程序演示了如何使用 get_context() 函数设置进程启动:

import multiprocessing
import os
print("当前进程ID:",os.getpid())

# 定义一个函数,准备作为新进程的 target 参数
def action(name,*add):
    print(name)
    for arc in add:
        print("%s --当前进程%d" % (arc,os.getpid()))
if __name__=='__main__':
    #定义为进程方法传入的参数
    my_tuple = ("http://c.biancheng.net/python/",\
                "http://c.biancheng.net/shell/",\
                "http://c.biancheng.net/java/")
    #设置使用 fork 方式启动进程
    ctx = multiprocessing.get_context('spawn')
   
    #用 ctx 代替 multiprocessing 模块创建子进程,执行 action() 函数
    my_process = ctx.Process(target = action, args = ("my_process进程",*my_tuple))
    #启动子进程
    my_process.start()

程序执行结果为:

当前进程ID: 18632
当前进程ID: 16700
my_process进程
http://c.biancheng.net/python/ --当前进程16700
http://c.biancheng.net/shell/ --当前进程16700
http://c.biancheng.net/java/ --当前进程16700

2.3 OpenMP

这里我们需要了解一个很有名的API:OpenMP。OpenMP(Open Multi-Processing)是一套支持跨平台共享内存方式的多线程并发的编程API,使用C,C++和Fortran语言,可以在大多数的处理器体系和操作系统中运行。简单来说,这是一个c/c++/Fortran等语言编译器的一个扩展,使得你不用写多线程代码,可以直接在原来代码上加上一行编译器看的懂得注释,编译器就会自动帮你多线程运行一些耗cpu的操作。在GCC中,它叫做libgomp

OpenMP的好处是显而易见的:

  • 方便:不用修改原来的代码,只需要添加一行类似的注释的东西
  • 安全:对于不支持OpenMP的编译器,会忽略该行,所以很安全

但是就如同世界上所有的事物都有两面性一样,OpenMP也是有缺点的,简单来说就是当你使用fork()时,如果父进程和子进程同时使用OpenMP,且父进程先使用OpenMP再调用fork(),则会造成子进程挂起

这是一个来自官网的例子:

#include <stdio.h>
#include <sys/wait.h>
#include <unistd.h>
void a()
{
	#pragma omp parallel num_threads(2)
    {
      puts("para_a"); // output twice
    }
    puts("a ended"); // output once
}
void b()
{
    #pragma omp parallel num_threads(2)
    {
      puts("para_b");
    }
    puts("b ended");
}
 int main() {
	a();   // Invokes OpenMP features (parent process)
	int p = fork();
	if(!p)
	{
		b(); // ERROR: Uses OpenMP again, but in child process
		_exit(0);
	}
	wait(NULL);
	return 0;
}

如上代码,父进程调用的函数a()中先使用了OpenMP,然后调用了fork(),那么子进程中b()中的"b ended"永远不会执行

按照官网的说法,这是一个无法解决的问题,“There is currently no workaround; the libgomp API does not specify functions that can be used to prepare for a call to fork().”

三、解密

我的这个问题的主要原因是torch.einsumnumpy.dottorch.matmul等等各种矩阵运算使用了OpenMP,复现了上面的问题,造成了子进程的挂起

其实Pytorch的子进程挂起的问题跟硬件,操作系统,线程数,矩阵规模都有关系,按照numpy的上的Github的第5752号Issues的看法

https://github.com/numpy/numpy/issues/5752

By default, Python multiprocessing does fork without exec which breaks various libraries that use posix thread pools (or other) internally (Accelerate, CUDA, libgomp the OpenMP implementation of gcc). This is probably still the case under some circumstances (e.g. data size) for Apple Accelerate although it seems to depend on the versions. The first time we observed the issue was on OSX 10.7. We (@cournape and I) reported the bug and Apple replied: wontfix, fork without exec is a POSIX standard violation (which is true BTW).

I issued a patch to OpenBLAS to make it’s non-OpenMP thread pool fork safe in the past:

xianyi/OpenBLAS#294

It would be interesting to retry if this fix still works for the latest version of OpenBLAS. As OpenBLAS is quite fast we could use it for the wheels (if all numpy + scipy tests pass under OSX with OpenBLAS).

ATLAS is robust by default too.

In Python 3.4+, multiprocessing has a forkserver start method to mitigate that issue but it is not used by default.

总结一下可能的情况:

  • 你使用了某个版本的apple accelerate
  • 某些情况下,data size过大也会导致这个现象
  • 主进程和子进程没有按照OpenMP规范使用(详细看下面的2.2)
  • 你在主进程初始化cuda,然后调用fork(),并在子进程中也是用cuda

3.1 解决办法

尝试使用spawn或者forkserver模式

例如:

import multiprocessing as mp
mp.set_start_method("spawn")  # 使用spqwn模式
# mp.set_start_method("forkserver")   # 使用forkserver模式

"""
这里执行多进程代码
"""

torch.set_num_threads(1)

  • 在父进程中创建子进程之前,执行torch.set_num_threads(1),并在子进程一开始也执行torch.set_num_threads(1)
    这种方式能解决问题,但是是讨巧的方式,其实它是限制了OpenMP在进程中使用多线程加速,属于magic方法,不提倡。但是有一个可能改进:“One possible improvement is to register a pthread_atfork handler that calls omp_set_num_threads(1) in the prepare and restores the value in parent and possibly child.”

使用基于OpenBLAS的库

四、参考文献

  • 8
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值