A PyTorch Distributed Training (DDP) Code Template

When several GPUs are available, PyTorch DDP can be used to train on multiple cards at once. This post records a code template for training with DDP. The key points are:

  • compatibility with both DDP multi-GPU and ordinary single-GPU training;
  • DDP setup (torch.distributed.init_process_group and related calls) and teardown (torch.distributed.destroy_process_group);
  • the DataLoader's sampler (DistributedSampler);
  • wrapping the model with torch.nn.parallel.DistributedDataParallel;
  • saving and loading the DDP-wrapped model from the rank 0 process;
  • validating only on the rank 0 process.

Code

  • Based on [1]
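Before the full template, here is a minimal, self-contained sketch of the flow on a single node. It uses a toy model and dataset and an arbitrarily chosen port (23456); the template that follows adds the single-card fallback, checkpointing, resuming, and validation.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def worker(gpu_id, world_size):
    # start DDP: one process per GPU
    dist.init_process_group("nccl", init_method="tcp://localhost:23456",
                            rank=gpu_id, world_size=world_size)
    torch.cuda.set_device(gpu_id)

    # toy data: each process sees its own shard via DistributedSampler
    ds = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=gpu_id, shuffle=True)
    loader = DataLoader(ds, batch_size=32, sampler=sampler)

    # wrap a toy model with DDP
    model = DDP(torch.nn.Linear(10, 2).cuda(gpu_id), device_ids=[gpu_id])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle consistently across processes
        for x, y in loader:
            loss = criterion(model(x.cuda(gpu_id)), y.cuda(gpu_id))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if gpu_id == 0:  # only rank 0 saves, and through `model.module`
        torch.save(model.module.state_dict(), "toy.pth")
    dist.destroy_process_group()  # end DDP

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    mp.spawn(worker, args=(n_gpus,), nprocs=n_gpus, join=True)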
import argparse, os, json, datetime, socket
import torch
import torch.nn as nn
from torchvision import transforms
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

def free_port():
    """find an available port
    Ref: https://www.cnblogs.com/mayanan/p/15997892.html
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(("", 0))
        _, port = tcp.getsockname()
    return port

def main(gpu_id, world_size, args):
    # initialise DDP
    if args.ddp:
        args.gpu = gpu_id  # local GPU id used by this process
        args.rank = args.rank * world_size + gpu_id  # global rank = node rank * GPUs per node + local GPU id
        print("rank:", args.rank, ", gpu id:", gpu_id)
        dist.init_process_group(
            backend=args.backend,
            init_method=f"tcp://localhost:{args.port}",
            rank=args.rank,
            world_size=world_size,
            # timeout=datetime.timedelta(minutes=10)  # custom maximum wait time for collectives
        )
        torch.cuda.set_device(args.gpu)
        # dist.barrier()
        device = torch.device(f'cuda:{args.gpu}')  # index by the local GPU id, not the global rank
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # augmentations
    train_trfm = transforms.Compose([
        # ...augmentations for training...
    ])
    val_trfm = transforms.Compose([
        # ...augmentations for validation...
    ])

    # datasets
    train_ds = MyDataset(train_trfm)
    val_ds = MyDataset(val_trfm)

    # samplers & data loaders
    if args.ddp:
        # with DDP, shuffle and drop_last are specified on the sampler...
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_ds, num_replicas=world_size, rank=args.rank, shuffle=True, drop_last=True)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_ds, num_replicas=world_size, rank=args.rank, shuffle=False, drop_last=False)

        # ...and must NOT be set on the DataLoader (i.e. leave shuffle at its default False)
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=args.batch_size, num_workers=4, pin_memory=torch.cuda.is_available(), sampler=train_sampler)
        val_loader = torch.utils.data.DataLoader(
            val_ds, batch_size=32, num_workers=4, pin_memory=torch.cuda.is_available(), sampler=val_sampler)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=args.batch_size, num_workers=4, pin_memory=torch.cuda.is_available(), shuffle=True, drop_last=True)
        val_loader = torch.utils.data.DataLoader(
            val_ds, batch_size=32, num_workers=4, pin_memory=torch.cuda.is_available(), shuffle=False, drop_last=False)


    model = MyModel().to(device)

    # wrap the model with DDP
    if args.ddp:
        model = DDP(model, device_ids=[args.gpu])


    # loss and optimizer are written as usual
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    start_epoch = global_step = 0


    # resume from checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), args.resume
        if args.ddp:
            # each process maps the checkpoint onto its own GPU
            ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
            # sync so no process races ahead before everyone has loaded
            dist.barrier()
            # broadcast rank 0's copy so every process resumes from exactly the same state
            obj_list = [ckpt]
            dist.broadcast_object_list(obj_list, src=0)
            ckpt = obj_list[0]
            # a DDP-wrapped model is accessed through `model.module`
            model.module.load_state_dict(ckpt['model'])
        else:
            ckpt = torch.load(args.resume)
            model.load_state_dict(ckpt['model'])

        # the rest loads as usual
        optimizer.load_state_dict(ckpt['optimizer'])
        start_epoch = ckpt['epoch'] + 1
        global_step = ckpt['global_step'] + 1


    # training
    for epoch in range(start_epoch, args.epoch):
        print('\t', epoch, end='\r')
        if args.ddp:
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

        model.train()
        for i, (x, y) in enumerate(train_loader):
            print(i, end='\r')
            pred = model(x.to(device))
            loss = criterion(pred, y.to(device))  # move targets to the same device as the predictions
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

        # saving checkpoint & validating
        # only on the rank 0 process, to avoid duplicated work
        if 0 == args.rank:
            # epoch ckpt
            sd = {
                "epoch": epoch,
                "global_step": global_step,
                "optimizer": optimizer.state_dict(),
            }
            if args.ddp:
                # a DDP model's weights live under `model.module`
                sd['model'] = model.module.state_dict()
            else:
                sd['model'] = model.state_dict()
            torch.save(sd, os.path.join(args.log_path, f"checkpoint-{epoch}.pth"))
            # note: a barrier to make all ranks wait for the checkpoint would have to be
            # called by every process, i.e. outside this rank-0 block, or it will deadlock
            # dist.barrier()


            # validation
            model.eval()
            for i, (x, y) in enumerate(val_loader):
                with torch.no_grad():
                    pred = model(x.to(device))
                # ...metric calculation on `pred` and `y`, e.g. accuracy...

    # finish training
    if args.ddp:
        # tear down DDP
        dist.destroy_process_group()


if '__main__' == __name__:
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset', type=str)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--epoch', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--log_path', type=str, default="log")
    parser.add_argument('--resume', type=str, default="", help="checkpoint to resume")
    # DDP
    parser.add_argument('--ddp', action="store_true")
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--backend', type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
    parser.add_argument('--port', type=int, default=10000)
    args = parser.parse_args()

    # save config
    os.makedirs(args.log_path, exist_ok=True)
    with open(os.path.join(args.log_path, "config.json"), "w") as f:
        json.dump(args.__dict__, f, indent=1)

    if args.ddp:
        world_size = torch.cuda.device_count()
        print("world size:", world_size)
        args.port = free_port() # pick 1 available port
        mp.set_start_method('spawn')
        mp.spawn(
            main,
            args=(world_size, args),
            nprocs=world_size,
            join=True
        )
    else:
        main(0, 1, args)
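For reference, a possible invocation using the arguments defined above: `python train.py my_dataset` runs ordinary single-card (or CPU) training, while `python train.py my_dataset --ddp` spawns one process per visible GPU; here `train.py` and `my_dataset` are placeholder names for the script file and the positional dataset argument.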

Test DDP Backend

PyTorch DDP ships with three built-in backends: nccl, gloo, and mpi; the first two are the most commonly used. ChatGPT suggested the following program to test which backends the local machine supports:

import os
import socket
import torch.distributed as dist

def free_port():
    """find an available port (same as above)"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(("", 0))
        _, port = tcp.getsockname()
    return port

def check_backend(backend):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(free_port())
    try:
        dist.init_process_group(backend, rank=0, world_size=1)
        print(f"{backend} backend initialized successfully.")
        dist.destroy_process_group()
    except Exception as e:
        print(f"Failed to initialize {backend} backend: {e}")

if __name__ == "__main__":
    backends = ["nccl", "gloo", "mpi"]
    for backend in backends:
        check_backend(backend)
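As a rule of thumb: nccl requires NVIDIA GPUs and will fail on a CPU-only machine, gloo works on both CPU and GPU, and mpi is only available when PyTorch has been built with MPI support, so most prebuilt wheels will report a failure for it.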

References

  1. snap-research/MoCoGAN-HD
  2. PyTorch distributed training gets stuck (pytorch分布式卡住)
  3. Getting an available port in Python: https://www.cnblogs.com/mayanan/p/15997892.html