torch.nn.parallel.DistributedDataParallel
import torch
from torch.nn.parallel import DistributedDataParallel

torch.distributed.init_process_group(backend="nccl")
model = model.cuda(local_rank)  # local_rank: this process's GPU index (see DEBUG below)
model = DistributedDataParallel(model, device_ids=[local_rank])
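Those lines are the core wiring; below is a minimal self-contained sketch of a main.py built around them. The model, optimizer, and dummy batch are placeholders (assumptions, not from the source); only the DDP setup matches the snippet above.

import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # filled in by torch.distributed.launch
args = parser.parse_args()

dist.init_process_group(backend="nccl")  # reads MASTER_ADDR/RANK/WORLD_SIZE from env vars set by the launcher
torch.cuda.set_device(args.local_rank)

model = torch.nn.Linear(10, 1).cuda(args.local_rank)  # placeholder model
model = DistributedDataParallel(model, device_ids=[args.local_rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# one dummy training step; real code would use a DataLoader with a
# DistributedSampler so each process sees a different data shard
x = torch.randn(32, 10).cuda(args.local_rank)
y = torch.randn(32, 1).cuda(args.local_rank)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()  # DDP all-reduces gradients across processes here
optimizer.step()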
Run from the command line (--nproc_per_node sets the number of processes, one per GPU; it defaults to 1):
python3 -m torch.distributed.launch --nproc_per_node=N main.py
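Note: since roughly PyTorch 1.10, torchrun is the recommended replacement for torch.distributed.launch; it delivers the local rank through the LOCAL_RANK environment variable instead of a --local_rank argument:
torchrun --nproc_per_node=N main.py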
DEBUG
- If you use argparse, add a --local_rank argument: torch.distributed.launch passes --local_rank=<rank> to every process it spawns (see the sketch above).
- Using Embedding in the code seems to trigger errors; one possible cause is that DDP over the NCCL backend cannot all-reduce the sparse gradients produced by nn.Embedding(sparse=True).
- For details, see https://zhuanlan.zhihu.com/p/86441879
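The source does not name the exact Embedding error, so the sparse-gradient explanation above is an assumption; if it applies, keeping the embedding dense avoids it (sizes below are placeholders):

import torch.nn as nn

# emb = nn.Embedding(10000, 128, sparse=True)   # sparse gradients: may fail under DDP + NCCL
emb = nn.Embedding(10000, 128, sparse=False)    # dense gradients (the default): works with DDP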