动手实践看懂深度学习的DP和DDP

摘要

为了尽可能加快训练,我们通常会使用两种GPU并行手段:DP(DataParallel)和DDP(DistributedDataParallel)。DP只开一个进程(用多线程驱动各张卡),数据分发、输出汇总和参数更新都集中在主卡上,导致计算资源分配不均;DDP则倾向于一张卡开一个进程,每个进程独立完成自己的前向、反向和参数更新,使计算资源能够得到最大化的利用。本文会快速、形象地过一遍DP和DDP,并告诉大家如何在代码层面上实践。

DP

在这里插入图片描述
从图中我们可以看出,在forward环节,gpu1会先拿到整个batch的数据,把它切分后分发给其他的gpu(它自己也留一份);同时它把模型复制成4份,每个gpu各拿一份。每个gpu各自跑自己的forward,跑完后将output汇总到gpu1,由gpu1计算loss以及loss对各个output的梯度。接着进入backward:gpu1把这些梯度分发回对应的gpu,其他gpu各自进行反向计算,计算完后再把参数梯度汇总(reduce)到gpu1,由gpu1更新参数。我们可以看到,在gpu1分配任务和更新参数的时候,其他的gpu其实都是闲置的,所以利用率上不去,所有人都得等gpu1。那么我们能不能想一种新方法,让每个gpu拿到数据后自己跑前向和反向,并且自己更新参数呢?DDP这不就来了嘛!

DDP

在这里插入图片描述
秉着尽量少理论、多形象的原则,为了加速理解,直接看图。我们将数据以一个一个batch的形式传入网络。图中有两台machine,每台machine上各有两张gpu。每张gpu上都有自己的model(都是同一个model的复制品)和optimizer。DistributedSampler会事先把数据集划分好,每张gpu只读取属于自己的那一份;然后每张gpu各自跑自己的前向和反向传播。在反向传播过程中,每张gpu通过DDP的后端通讯(例如NCCL)与其他所有gpu同步梯度(all-reduce并求平均),于是所有gpu拿到的是完全相同的梯度。之后各自的optimizer吭哧吭哧地更新自己那份模型参数,各个副本因此始终保持一致。DDP就这么简单。

DDP代码实践

在这里插入图片描述

# 1. Imports needed for DDP
# Model side: the DDP wrapper
from torch.nn.parallel import DistributedDataParallel as DDP
# Data side: gives each process its own shard of the dataset
from torch.utils.data.distributed import DistributedSampler
# DDP runtime itself: process group setup and collectives
import torch.distributed as dist
# Also used by the snippets below
import logging
import os

import torch
from torch.utils.data import DataLoader

# 2. Backend multi-GPU communication and this process's GPU index (LOCAL_RANK)
if DDP_ON:
    # torchrun exports LOCAL_RANK (the GPU index on this node) for each process it starts.
    LOCAL_RANK = device_id = int(os.environ["LOCAL_RANK"])
    # Bind this process to its GPU before NCCL is initialized.
    torch.cuda.set_device(device_id)
    # init_process_group lives in torch.distributed; only `dist` is imported above.
    dist.init_process_group(backend="nccl")
    # Number of processes across *all* nodes; torch.cuda.device_count() would
    # only count the GPUs visible on this node.
    WORLD_SIZE = dist.get_world_size()

    device = torch.device('cuda', device_id)  # device_id is an int; device is a torch.device
    print(f"Start running basic DDP on rank {LOCAL_RANK}.")
    logging.info(f'Using device {device_id}')

# 3. Wrap the model with DDP (the module must already live on this process's GPU)
net = DDP(net.to(device), device_ids=[device_id], output_device=device_id)


# 4. Feed data to the GPUs
loader_args = dict(batch_size=batch_size, num_workers=WORLD_SIZE*4, pin_memory=True)  # batch_size is per process
if DDP_ON:
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set, sampler=train_sampler, **loader_args)
else:
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)

# No distributed sampler for validation: every process sees the whole val set.
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)


# 5. set_epoch reseeds DistributedSampler's shuffle so each epoch gets a new order (see figure below)
# ref: https://blog.csdn.net/weixin_41978699/article/details/121742647
for epoch in range(start, start+epochs):
    if LOCAL_RANK == 0:
        print('lr: ', optimizer.param_groups[0]['lr'])

    net.train()
    epoch_loss = 0

    # Only DistributedSampler has set_epoch; the plain shuffling sampler used
    # without DDP does not.
    if DDP_ON:
        train_loader.sampler.set_epoch(epoch)

在这里插入图片描述

启动

# torchrun exports LOCAL_RANK to every process. The training script's flags are
# --ddp_mode (required to enable DDP), --batch-size and --learning-rate.
torchrun --nproc_per_node=4 \
          multigpu_torchrun.py \
          --ddp_mode \
          --batch-size 4 \
          --learning-rate 1e-3
# torch.distributed.launch only exports LOCAL_RANK when given --use_env; without it,
# it passes a --local_rank argument that the script's argparse does not accept.
# Note: "--nproc_per_node=4" must not contain spaces around "=".
python -m torch.distributed.launch \
      --nproc_per_node=4 \
      --use_env \
        train.py \
      --ddp_mode \
      --batch-size 4

完整代码布局参考

import argparse
import logging
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet import UNet
import os
import torch.distributed as dist

# for reproducibility
import random
import numpy as np
import torch.backends.cudnn as cudnn

# ABOUT DDP
# for model loading in ddp mode
from torch.nn.parallel import DistributedDataParallel as DDP
# for data loading in ddp mode
from torch.utils.data.distributed import DistributedSampler

import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group



def init_seeds(seed=0, cuda_deterministic=True):
    """Seed the Python, NumPy and PyTorch RNGs and choose the cuDNN mode.

    Args:
        seed: value passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed`` (in that order).
        cuda_deterministic: when truthy, turn on cuDNN determinism and turn
            off benchmark autotuning (slower, more reproducible); when falsy,
            do the opposite (faster, less reproducible).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    # The two cuDNN flags are always set to opposite values.
    cudnn.deterministic = bool(cuda_deterministic)
    cudnn.benchmark = not cuda_deterministic

def train_net(net,
              device,
              start: int = 0,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False):
    """Train ``net`` (plain, or wrapped in DistributedDataParallel when DDP_ON).

    Reads the module-level globals ``DDP_ON``, ``LOCAL_RANK``, ``WORLD_SIZE``,
    ``dir_img``, ``dir_mask``, ``dir_checkpoint`` and ``args``.

    Args:
        net: the UNet, or the DDP wrapper around it when DDP_ON.
        device: torch.device this process trains on.
        start: first epoch index (epochs run from ``start`` to ``start + epochs - 1``).
        epochs: number of epochs to run.
        batch_size: per-process batch size; each process builds its own loader,
            so under DDP the effective global batch is batch_size * WORLD_SIZE.
        learning_rate: initial AdamW learning rate (cosine-annealed per epoch).
        val_percent: fraction of the dataset held out for validation.
        save_checkpoint: save a state_dict every ``args.save_every`` epochs (rank 0 only).
        img_scale: image downscaling factor passed to the dataset.
        amp: enable mixed precision (autocast + GradScaler).
    """
    if DDP_ON:  # expose the wrapped model's attributes on the DDP wrapper
        net.n_channels = net.module.n_channels
        net.n_classes = net.module.n_classes

    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions.
    # The fixed generator seed makes every rank compute the same split, so the
    # DistributedSampler below shards the same train set on all processes.
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders (batch_size is for a single process/GPU).
    # NOTE(review): num_workers is per process, so the total worker count grows
    # as WORLD_SIZE**2 * 4 — confirm this is intended.
    loader_args = dict(batch_size=batch_size, num_workers=WORLD_SIZE * 4, pin_memory=True)
    if DDP_ON:
        # Each rank reads only its own shard of train_set.
        train_loader = DataLoader(train_set, sampler=DistributedSampler(train_set), **loader_args)
    else:
        train_loader = DataLoader(train_set, shuffle=True, **loader_args)

    # Validation is not sharded: every rank iterates the full validation set.
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # Only this process logs to wandb, shows progress and writes checkpoints.
    is_main = LOCAL_RANK == 0

    # (Initialize logging)
    if is_main:
        experiment = wandb.init(project='U-Net-DDP', resume='allow', anonymous='must')
        experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                      val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
                                      amp=amp))

        logging.info(f'''Starting training:
                Epochs:          {epochs}
                Start from:      {start}
                Batch size:      {batch_size}
                Learning rate:   {learning_rate}
                Training size:   {n_train}
                Validation size: {n_val}
                Checkpoints:     {save_checkpoint}
                Device:          {device.type}
                Images scaling:  {img_scale}
                Mixed Precision: {amp}
            ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    global_step = 0

    # Aim for ~5 evaluation rounds per epoch. len(train_loader) is the number
    # of steps *this* process takes per epoch; under DDP that is about
    # 1/WORLD_SIZE of the single-GPU count, so it must not be derived from
    # n_train alone.
    division_step = len(train_loader) // 5

    # 5. Begin training
    for epoch in range(start, start + epochs):
        if is_main:
            print('lr: ', optimizer.param_groups[0]['lr'])

        net.train()

        # Reseed DistributedSampler's shuffle so each epoch sees a new order.
        # The plain shuffling sampler used without DDP has no set_epoch.
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)

        disable = not is_main

        # len(sampler) is the number of samples this rank sees per epoch.
        with tqdm(total=len(train_loader.sampler), desc=f'Epoch {epoch}/{epochs+start}', unit='img',
                  disable=disable) as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                loss_value = loss.item()  # one device->host sync per step instead of three
                pbar.update(images.shape[0])
                global_step += 1

                if is_main:
                    experiment.log({
                        'train loss': loss_value,
                        'step': global_step,
                        'epoch': epoch
                    })
                pbar.set_postfix(**{'loss (batch)': loss_value})

                # Evaluation round
                if division_step > 0 and global_step % division_step == 0:
                    histograms = {}
                    if is_main:  # histograms are only consumed by rank 0's wandb log
                        for tag, value in net.named_parameters():
                            tag = tag.replace('/', '.')
                            if not torch.isinf(value).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            # grad can be None after zero_grad(set_to_none=True)
                            # for a parameter that received no gradient this step.
                            if value.grad is not None and not torch.isinf(value.grad).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    # Called on every rank, as before: evaluate() receives the
                    # DDP-wrapped net, whose forward may involve cross-process
                    # communication, so skipping it on some ranks could hang.
                    val_score = evaluate(net, val_loader, device, disable_log=disable)

                    if is_main:
                        logging.info('Validation Dice score: {}'.format(val_score))
                        experiment.log({
                            'learning rate': optimizer.param_groups[0]['lr'],
                            'validation Dice': val_score,
                            'images': wandb.Image(images[0].cpu()),
                            'masks': {
                                'true': wandb.Image(true_masks[0].float().cpu()),
                                'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                            },
                            'step': global_step,
                            'epoch': epoch,
                            **histograms
                        })
        scheduler.step()
        if save_checkpoint and is_main and (epoch % args.save_every == 0):
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            # Save the unwrapped model so keys carry no 'module.' prefix; works
            # whether or not net is DDP-wrapped.
            model = getattr(net, 'module', net)
            torch.save(model.state_dict(), str(dir_checkpoint / 'DDP_checkpoint_epoch{}.pth'.format(epoch)))

            logging.info(f'Checkpoint {epoch} saved!')


##################################### arguments ###########################################
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                    help='Learning rate', dest='lr')
parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                    help='Percent of the data that is used as validation (0-100)')
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
parser.add_argument('--exp_name', type=str, default='hgb_exp')
parser.add_argument('--ddp_mode', action='store_true')
parser.add_argument('--save_every', type=int, default=5)
parser.add_argument('--start_from', type=int, default=0)

args = parser.parse_args()

dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
dir_checkpoint = Path('./checkpoints/')

DDP_ON = bool(args.ddp_mode)

# The root logger defaults to WARNING; without a configured handler every
# logging.info(...) call in this script would be silently dropped.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

#########################################################################################

if DDP_ON:
    # torchrun (or torch.distributed.launch --use_env) exports LOCAL_RANK per process.
    LOCAL_RANK = device_id = int(os.environ["LOCAL_RANK"])
    # Bind this process to its GPU *before* creating the NCCL process group.
    torch.cuda.set_device(device_id)
    init_process_group(backend="nccl")
    # Processes across all nodes; torch.cuda.device_count() only sees this node.
    WORLD_SIZE = dist.get_world_size()

    device = torch.device('cuda', device_id)  # device_id is an int; device is a torch.device
    print(f"Start running basic DDP on rank {LOCAL_RANK}.")
    logging.info(f'Using device {device_id}')
else:
    # Single-process fallback: this process plays rank 0 (logging, checkpoints).
    LOCAL_RANK = device_id = 0
    WORLD_SIZE = 1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')


if __name__ == '__main__':
    # Highly recommended reading (PyTorch DDP tutorial series):
    # 1. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html#multi-gpu-training-with-ddp
    # 2. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html

    init_seeds(0)
    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    if LOCAL_RANK == 0:
        print(f'Network:\n'
            f'\t{net.n_channels} input channels\n'
            f'\t{net.n_classes} output channels (classes)\n'
            f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        # Checkpoints saved from a DDP-wrapped model carry a leading 'module.'
        # prefix; strip only that prefix so the plain UNet can load them.
        # ref: https://blog.csdn.net/hustwayne/article/details/120324639
        prefix = 'module.'
        state_dict = torch.load(args.load, map_location=device)
        net.load_state_dict({(k[len(prefix):] if k.startswith(prefix) else k): v
                             for k, v in state_dict.items()})

        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    if DDP_ON:
        # Wrap our model with DDP (requires the process group created above).
        net = DDP(net, device_ids=[device_id], output_device=device_id)

    try:
        train_net(net=net,
                  start=args.start_from,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  learning_rate=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp)
    except KeyboardInterrupt:
        # Only one rank writes, so processes don't race on the same file.
        if LOCAL_RANK == 0:
            torch.save(getattr(net, 'module', net).state_dict(), 'INTERRUPTED_DDP.pth')
            logging.info('Saved interrupt')
        raise
    finally:
        if DDP_ON:
            destroy_process_group()
  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

落难Coder

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值