DistributedDataParallel
DistributedDataParallel is PyTorch's interface for distributed data-parallel training:
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[args.local_rank], output_device=args.local_rank,
    # this should be removed if we update BatchNorm stats
    broadcast_buffers=True,
)
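The args.local_rank in this snippet is supplied by PyTorch's multi-process launcher, which starts one process per GPU and passes each one a --local_rank argument. A minimal sketch of the surrounding setup, assuming the script is started as python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py (the argument parsing here is my addition, not part of the original snippet):

import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)  # bind this process to its own GPU
# the launcher sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE in the environment
dist.init_process_group(backend="nccl", init_method="env://")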
After wrapping the model with DDP as above, I initially assumed the gradients from model(input) had to be reduced manually; after reading several blog posts and the official docs, I confirmed that you simply train with the DDP model directly. The principle: every GPU computes the loss on its own data shard, and each process maintains its own optimizer; during the backward pass, DDP averages the gradients across all processes with an allreduce, so every replica ends up with identical gradients and applies the identical parameter update.
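In other words, each process runs an ordinary single-GPU training loop and never calls a reduce explicitly. A minimal sketch, assuming model is the DDP wrapper and optimizer is set up as above, and using a DistributedSampler so each rank sees its own shard of the data (the toy dataset is purely illustrative):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 5))
sampler = DistributedSampler(dataset)  # shards the data across ranks
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

loss_fn = nn.MSELoss()
for inputs, labels in loader:
    inputs = inputs.cuda(args.local_rank)
    labels = labels.cuda(args.local_rank)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)  # `model` is the DDP wrapper
    loss.backward()   # DDP allreduce-averages the gradients across ranks here
    optimizer.step()  # every rank applies the identical update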
Here is an example from the official tutorial:
import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see the same parameters as they all start from the
        # same random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure the other processes load the model only after
    # process 0 has saved it.
    dist.barrier()
    # configure map_location properly: remap tensors saved from cuda:0 onto
    # this process's own device
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Use a barrier() to make sure that all processes have finished reading the
    # checkpoint before rank 0 deletes it
    dist.barrier()

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
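The torch.multiprocessing import above is what launches the example: the tutorial spawns one process per rank with mp.spawn. The world size of 2 below is an assumption for illustration (it needs at least two GPUs):

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    run_demo(demo_checkpoint, 2)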
References
- https://zhuanlan.zhihu.com/p/86441879
- Official tutorial: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
- torch.distributed docs: https://pytorch.org/docs/stable/distributed.html