Spawning 子线程
仅支持 Python >= 3.4.
依赖于 spawn
启动方法(在 Python 的 multiprocessing
包中)。
通过创建进程
实例并调用join来等待它们完成,可以生成大量子进程来执行某些功能。这种方法在处理单个子进程时工作得很好,但在处理多个进程时可能会出现问题。
也就是说,顺序连接进程意味着它们将顺序终止。如果没有,并且第一个进程没有终止,那么进程终止将不被注意。 此外,没有用于错误传播的本地工具.
下面的spawn
函数解决了这些问题,并负责错误传播、无序终止,并在检测到其中一个错误时主动终止进程.
torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False)
Spawns nprocs
进程运行 fn
使用参数 args
.
如果其中一个进程以非零退出状态退出,则会杀死其余进程,并引发异常,导致终止。在子进程中捕获异常的情况下,将转发该异常,并将其跟踪包含在父进程中引发的异常中。
参数:
-
fn (function) –
函数被称为派生进程的入口点。必须在模块的顶层定义此函数,以便对其进行pickle和派生。这是多进程强加的要求。
该函数称为
fn(i, *args)
,其中i
是进程索引,args
是传递的参数元组。 -
args (tuple) – 传递给
fn
的参数. -
nprocs (int) – 派生的进程数.
-
join (bool) – 执行一个阻塞的join对于所有进程.
-
daemon (bool) – 派生进程守护进程标志。如果设置为True,将创建守护进程.
class torch.multiprocessing.SpawnContext
由 spawn()
返回, 当 join=False
.
join(timeout=None)
尝试连接此派生上下文中的一个或多个进程。如果其中一个进程以非零退出状态退出,则此函数将杀死其余进程,并引发异常,导致第一个进程退出。
返回 True
如果所有进程正常退出, False
如果有更多的进程需要 join.
使用例子:
参考:
https://github.com/sangho-vision/wds_example/blob/850fdff046e4b84215722d291ffad8c024062607/run.py
import utils.multiprocessing as mpu
if cfg.NUM_GPUS > 1:
torch.multiprocessing.spawn(
mpu.run,
nprocs=cfg.NUM_GPUS,
args=(
cfg.NUM_GPUS,
train,
cfg.DIST_INIT_METHOD,
cfg.SHARD_ID,
cfg.NUM_SHARDS,
cfg.DIST_BACKEND,
cfg,
),
daemon=False,
)
multiprocessing.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Multiprocessing helpers."""
import torch
def run(
local_rank, num_proc, func, init_method, shard_id, num_shards, backend, cfg
):
"""
Runs a function from a child process.
Args:
local_rank (int): rank of the current process on the current machine.
num_proc (int): number of processes per machine.
func (function): function to execute on each of the process.
init_method (string): method to initialize the distributed training.
TCP initialization: equiring a network address reachable from all
processes followed by the port.
Shared file-system initialization: makes use of a file system that
is shared and visible from all machines. The URL should start with
file:// and contain a path to a non-existent file on a shared file
system.
shard_id (int): the rank of the current machine.
num_shards (int): number of overall machines for the distributed
training job.
backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
supports, each with different capabilities. Details can be found
here:
https://pytorch.org/docs/stable/distributed.html
cfg (CfgNode): configs. Details can be found in
slowfast/config/defaults.py
"""
# Initialize the process group.
world_size = num_proc * num_shards
rank = shard_id * num_proc + local_rank
try:
torch.distributed.init_process_group(
backend=backend,
init_method=init_method,
world_size=world_size,
rank=rank,
)
except Exception as e:
raise e
torch.cuda.set_device(local_rank)
func(cfg)