PyTorch分布式训练指北
单卡用习惯了,多卡训练前的参数配置感觉还是有点麻烦,想着记录一下多卡分布的环境配置和cuda初始化方式
环境支持
首先得有torch和对应版本的cuda驱动,确保项目的requirements或者environment已经配置好。我们一般都是使用裸的算力、显卡以及服务器,需要先手动配置一下nccl
git clone https://github.com/NVIDIA/nccl-tests.git
cd nccl-tests
make
./build/all_reduce_perf -b 8 -e 256M -f 2 -g <YOUR GPU AMOUNT>
编译nccl环境,这个过程应该很快,最后一行是为了测试一下,-g后面跟你的机器(GPU)数量
这步编译很快,然后我们看项目代码怎么结合分布式
分布式初始化
我也参考了好几个分布式训练的项目,总结了一下
需要一个rank变量,其实是一个gpu编号,代码里只用设置一个0的值就可以(或者从输入流中获取),它在分布式里类似于一个循环变量的东西,运行起来时会在环境变量CUDA_VISIBLE_DEVICES中,使用它做一个初始化;实际上我们更多的使用环境变量(设置方便不用重新commit),设置环境变量
export RANK=0,1
export CUDA_VISIBLE_DIVICE=0,1
(双卡)
设置有效gpu数会在后面调用torch.cuda.device_count()的地方被隐式调用
import torch.multiprocessing as mp
import torch.distributed as dist
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend='nccl')
torch.backends.cudnn.benchmark = True
device = 'cuda'
分布式通讯是通过localhost的一个端口实现的,如果运行上述代码时只有单个机器(没有多卡),可能会在torch.distributed.init_process_group(backend='NCCL')处被阻塞(torch包中的迭代器找不到机器编号),ctrl+C杀不死,查看全部进程
ps aux
找到pid然后
kill -9 <PID>
模型的分布式
一个可以被多卡分布式使用的模型,必须是nn.Module的子类
device = 'cuda'
single_gpu = False
model = MODEL().to(device)
model = MODEL if single_gpu else torch.nn.parallel.DistributedDataParallel(
MODEL,
device_ids=[rank], # 接前面的rank
output_device=rank)
这里的MODEL必须是nn.Module的子类(使用分布式时,上面single_gpu设置为False即可,model获取到的是(DataParallel, DistributedDataParallel)对象)
注意上述device='cuda',不需要指定cuda编号,这个由分布式的多线程自动完成映射(分布式相当于是同时运行多个python源程序,将batch上的数据分配到不同的GPU上计算,最后算loss的时候能够再将所有GPU上的数据汇合)
后续的代码,使用MODEL(nn.Module)模型原型时调用model.module(<Input>)即可
运行分布式
像前面提到的在运行的终端里设置好环境变量
export RANK=0,1,2
export CUDA_VISIBLE_DEVICES=0,1,2
启动带有分布式的程序时运行
python -m torch.distributed.launch --nproc_per_node=3 --master_port=1010 YourCode.py
其中--nproc_per_node表示参与多卡分布式计算的GPU数量,--master_port表示GPU之间通信的localhost端口(其实也可以不用localhost通信,也可以用远程端口通讯,这就涉及到服务器集群了)。
注意,如果直接使用
python YourCode.py
启动,也会遇到前面说到的分布式阻塞问题,也只能kill掉相应的PID;注意一定要用python -m torch.distributed.lauch来启动。
我记得以前莫名其妙遇到一个多卡分布式socket报错(报错内容大致是:非阻塞方法调用了阻塞方法),但后来就没见到了,以后碰到了再补进来。