import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
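This example initializes the "gloo" backend, which also runs on CPU-only machines; for multi-GPU training, the NCCL backend is generally preferred. A minimal variant (setup_nccl is a hypothetical helper, not part of the original example):

def setup_nccl(rank, world_size):
    # same rendezvous as setup() above, but with the GPU-optimized NCCL backend
    # (assumption: CUDA GPUs and a PyTorch build with NCCL support are available)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)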
Next, we build a simple model, wrap it with DDP, and feed it some randomly generated data.
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
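With run_demo in place, the demo can be launched from an ordinary script entry point. A minimal sketch, assuming one DDP process per visible GPU (n_gpus is illustrative):

if __name__ == "__main__":
    # spawn one process per local GPU (assumption: at least one GPU is present)
    n_gpus = torch.cuda.device_count()
    run_demo(demo_basic, n_gpus)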
def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see the same parameters as they all start from
        # the same random parameters and gradients are synchronized in backward
        # passes. Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after
    # process 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.
    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
4.3. Combining DDP with Model Parallelism
First, define a model-parallel Module that splits its layers across two devices.
class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)
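A module that already spans several devices like this can be handed to DDP directly; in that case device_ids must be left unset, and DDP infers placement from the module's parameters. The sketch below is one way to drive it, under the assumption that each process owns two GPUs; the dev0 = rank * 2 / dev1 = rank * 2 + 1 mapping and the name demo_model_parallel are illustrative:

def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # each process drives two GPUs (assumption: 2 * world_size GPUs available)
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    # for multi-device modules, device_ids must not be set
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # the output lands on dev1, so the labels must live there too
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()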