pytorch的ddp分布式训练
相对于单卡,多卡分布式训练能够:1.处理更大的模型的数据集,2.充分利用实验室硬件资源,3.以及最重要的加速训练速度,目前已经被广泛使用。本文用于简化pytorch的ddp训练代码,并解释每一步的作用。
导包
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
torch.distributed
包是用于支持分布式深度学习的核心包。它提供了其中包括初始化分布式环境、同步张量、发送和接收消息等功能函数。torch.multiprocessing
是PyTorch中的多进程包,用于启动和管理多个进程。DistributedDataParallel
是PyTorch的分布式训练工具,用于复制模型到不同GPU上,训练并同步梯度和参数。torch.utils.data.distributed.DistributedSampler
是数据采样器(sampler)类,确保不同GPU或节点上的模型训练使用不同的数据子集,避免数据重复使用。
初始化进程组
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
dist.init_process_group('nccl', rank = rank, world_size = world_size)
-
MASTER_ADDR
和MASTER_PORT
是两个待设置的环境变量,指定进程主节点的地址和端口 -
nccl
为通信后端,rank为当前分配的进程的标识,world_size表示总进程数
划分和分发数据
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
batch_size = math.ceil(self.batch_size / self.world_size)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, shuffle=False, pin_memory=True)
- sampler的rank和num_replicas分别用于指定当前进程和总参与的进程数
- batch_size为设置数/总参与进程数,如设置为8,有4个GPU参与,每个GPU负责2个batch
- 因为sampler已经打乱了数据顺序,dataloader的参数shuffle=False
包装为DDP模型
ddp_kwargs = {'device_ids': [rank], 'find_unused_parameters': True}
model_ddp = DDP(model, **ddp_kwargs)
device_ids
参数指定了在哪个GPU上运行DDP
模型find_unused_parameters
参数告诉DDP
是否查找未使用的参数。模型中有一部分参数不参与更新则可能会报警告,但如果你确定这个改动,则不需要管它。如GAN等模型训练辨别器时,生成器参数只参与前向传播,不参与更新。
启动多卡训练
mp.spawn(train,args=(world_size, model_args, num_train_steps, name, seed),nprocs=world_size,join=True)
- 第一个参数为训练函数
train
,args为传递给train函数的参数,同时还会隐式传递一个进程标识参数rank给train函数 nprocs
参数指定要启动的进程数量,通常等于分布式环境中的总进程数world_size
join
参数用于指定是否要等待所有进程完成后再继续执行后续代码。设置为True
表示要等待所有进程完成,然后再继续执行主进程的代码
释放资源
dist.destroy_process_group() # 训练结束释放进程组和其他资源
注意点:
-
model_ddp = DDP(model, **ddp_kwargs)
这一步会将模型进一步包装为ddp模型,如果想要调用原来model中的方法需要加.module
,如model_ddp.module.func(...)
。 -
同上,参数保存时也会在key前加
module
,可以在写入时或读取模型时候手动修改# 载入参数时修改state_dict的key值 ck = {k.replace('module.', ''): v for k, v in torch.load('model.pth')['state_dict'].items()} model.load_state_dict(ck)
-
使用梯度累积+多卡训练时,只有更新参数时需要梯度同步。故在梯度累积阶段,使用
model_ddp.no_sync()
主动不进行梯度同步