DDP多GPU训练模型

训练代码采用了DDP,并使用torchrun启动,以保证训练过程异常退出后,能够根据保存的模型快照接着训练。
训练代码:

import cifar10DataLoader as datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import torchvision
import torch
import copy
import time
from tqdm import tqdm
from resnet import ResNetBase
from config import Config
from torch import nn, optim 

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

# 初始化DDP
def ddp_setup():
    """Initialise the default NCCL process group for this worker.

    Reads ``LOCAL_RANK`` from the environment (the training script is meant to
    be launched with ``torchrun``) and makes that GPU the current CUDA device
    before the process group is created, so that this process's CUDA work and
    NCCL communicator use its own GPU rather than defaulting to device 0.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
    """
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    init_process_group(backend="nccl")

# 多GPU训练
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
# device = torch.device("cuda")
class Trainer:
    """DDP training loop with snapshot-based resume for single-node multi-GPU runs.

    The model is moved to the GPU given by ``LOCAL_RANK``, optionally restored
    from ``snapshot_path``, and then wrapped in ``DistributedDataParallel``.
    The process whose local rank is 0 writes a snapshot every ``save_every``
    epochs so that a restarted job can continue where it stopped.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        test_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str
    ) -> None:
        """Place the model on this process's GPU and resume from a snapshot if one exists.

        Args:
            model: network to train; moved to the local GPU and wrapped in DDP.
            train_data: loader iterated once per epoch by ``_run_epoch``.
            test_data: stored on the instance; not used by the methods of this class.
            optimizer: optimizer stepped once per batch.
            save_every: a snapshot is written when ``epoch % save_every == 0``.
            snapshot_path: file to resume from and to write snapshots to.

        Raises:
            KeyError: if ``LOCAL_RANK`` is not set in the environment.
        """
        # Local rank doubles as the CUDA device index (single-node, multi-GPU).
        self.lock_rank = int(os.environ["LOCAL_RANK"])
        # For multi-node training the global rank would be needed as well:
        # self.global_rank = int(os.environ["RANK"])
        self.module = model.to(self.lock_rank)
        self.train_data = train_data
        self.test_data = test_data
        self.optimizer = optimizer
        self.save_every = save_every
        # Remembered so that snapshots are written to the same file they are
        # resumed from.
        self.snapshot_path = snapshot_path
        self.epochs_run = 0
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)
        # Wrap only after the weights are restored: the snapshot stores the
        # state dict of the unwrapped model.
        self.module = DDP(self.module, device_ids=[self.lock_rank])

    def _load_snapshot(self, snapshot_path):
        """Restore model weights and the epoch counter from ``snapshot_path``."""
        # Map tensors onto this process's GPU; without map_location every rank
        # would restore onto the device the snapshot was saved from.
        snapshot = torch.load(snapshot_path, map_location=f"cuda:{self.lock_rank}")
        self.module.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        """Run one optimisation step (forward, cross-entropy loss, backward, step)."""
        self.optimizer.zero_grad()
        output = self.module(source)
        loss = torch.nn.functional.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        """Train for one pass over ``self.train_data``."""
        # Read the configured batch size instead of loading a whole batch just
        # to measure it.
        b_sz = self.train_data.batch_size
        print(f'[GPU{self.lock_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}')
        # A DistributedSampler seeds its shuffle from the epoch number; without
        # this call every epoch would iterate the data in the same order.
        sampler = getattr(self.train_data, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        for source, targets in tqdm(self.train_data):
            # Move the images and their labels to this process's GPU.
            source = source.to(self.lock_rank)
            targets = targets.to(self.lock_rank)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        """Write model weights and resume position to ``self.snapshot_path``.

        Args:
            epoch: index of the epoch that has just finished.
        """
        snapshot = {
            # The DDP wrapper keeps the real network in ``.module``.
            "MODEL_STATE": self.module.module.state_dict(),
            # ``epoch`` is already complete, so a resumed run must start at the
            # next one instead of repeating it.
            "EPOCHS_RUN": epoch + 1,
        }
        # Write to a temporary file and rename, so a crash in the middle of the
        # write cannot leave a truncated snapshot behind.
        tmp_path = f"{self.snapshot_path}.tmp"
        torch.save(snapshot, tmp_path)
        os.replace(tmp_path, self.snapshot_path)
        print(f'Epoch {epoch} | Training checkpoint saved at {self.snapshot_path}')

    def train(self, max_epochs: int):
        """Train from ``self.epochs_run`` up to (not including) ``max_epochs``."""
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            # Only local rank 0 writes the snapshot.
            if self.lock_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)

def prepare_trainData(batch_size: int):
    """Build the training DataLoader over ``./data`` with augmentation.

    Args:
        batch_size: number of samples per batch for this process.

    Returns:
        A ``DataLoader`` whose sample order comes from a ``DistributedSampler``.
    """
    augmentations = [
        transforms.ToTensor(),
        # Pad 4 pixels on every side, then take a random 32x32 crop.
        transforms.RandomCrop(32, padding=4),
        # Mirror the image horizontally with probability 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # Per-channel mean and standard deviation.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    train_set = datasets.CIFAR10_IMG(
        './data', train=True, transform=transforms.Compose(augmentations))
    # Under DDP the sampler decides which samples this process sees, so the
    # loader's own shuffling stays off.
    loader_options = dict(
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(train_set),
    )
    return DataLoader(train_set, **loader_options)

def prepare_testData(batch_size: int):
    """Build the test DataLoader over ``./data`` (no augmentation).

    Args:
        batch_size: number of samples per batch for this process.

    Returns:
        A ``DataLoader`` whose sample order comes from a ``DistributedSampler``.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel mean and standard deviation.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_set = datasets.CIFAR10_IMG('./data', train=False, transform=preprocessing)
    loader_options = dict(
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(test_set),
    )
    return DataLoader(test_set, **loader_options)
    
    
def load_train_objs():
    """Create the classifier and its optimizer from ``Config``.

    The model is a ``ResNetBase`` feature extractor followed by a linear layer
    that maps the last stage's channel count to the 10 output classes.

    Returns:
        ``(model, optimizer)`` where the optimizer is plain SGD with lr 1e-3.
    """
    # first_kernel_size must be passed by keyword: ResNetBase's fourth
    # positional parameter is img_channels, so a positional argument would be
    # taken as the input channel count and the first convolution would keep
    # its default 7x7 kernel.
    # NOTE: the first conv's weight shape follows the kernel size, so a
    # snapshot saved with a different first_kernel_size will not load.
    base = ResNetBase(Config.n_blocks, Config.n_channels, Config.bottlenecks,
                      first_kernel_size=Config.first_kernel_size)
    classification = nn.Linear(Config.n_channels[-1], 10)
    model = nn.Sequential(base, classification)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return model, optimizer

def main(save_every: int, total_epochs: int, snapshot_path: str="snapshot.pt"):
    """Set up DDP, build model and data, train, and tear the process group down.

    Args:
        save_every: interval (in epochs) passed to ``Trainer`` for snapshots.
        total_epochs: epoch index at which training stops (exclusive).
        snapshot_path: snapshot file handed to ``Trainer``.
    """
    # Initialise the distributed process group before anything that needs it
    # (the DistributedSampler and DDP wrapper created below).
    ddp_setup()
    try:
        model, optimizer = load_train_objs()
        train_data = prepare_trainData(batch_size=256)
        test_data = prepare_testData(batch_size=256)
        trainer = Trainer(model, train_data, test_data, optimizer, save_every, snapshot_path)
        trainer.train(total_epochs)
    finally:
        # Release the process group even when setup or training raises.
        destroy_process_group()


if __name__ == '__main__':
    import argparse

    # Same positional CLI as before (<total_epochs> <save_every>), but missing
    # or non-integer arguments now produce a usage message instead of a
    # traceback.
    parser = argparse.ArgumentParser(description="DDP training launched via torchrun")
    parser.add_argument("total_epochs", type=int, help="epoch index at which training stops")
    parser.add_argument("save_every", type=int, help="write a snapshot every this many epochs")
    args = parser.parse_args()

    # Number of GPUs visible to this process (informational only).
    world_size = torch.cuda.device_count()
    print(world_size)
    main(args.save_every, args.total_epochs)

#  torchrun --standalone --nproc_per_node=gpu cifarDDPnew.py 50 5 

# 此命令可以指定哪些GPU进行训练
# CUDA_VISIBLE_DEVICES=2,3 torchrun --standalone --nproc_per_node=gpu cifarDDPnew.py 100 5

模型代码:

from typing import List, Optional

import torch
from torch import nn 
from typing import Any, TypeVar, Iterator, Iterable, Generic
class Module(torch.nn.Module):
    r"""
    ``torch.nn.Module`` base class whose subclasses may define ``__call__``
    in place of ``forward``, which gives better type checking.

    When a subclass defines ``__call__`` in its own body, that function is
    re-registered as the subclass's ``forward`` and the ``__call__`` entry is
    removed from the subclass.

    `PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
    """

    def _forward_unimplemented(self, *input: Any) -> None:
        # No-op override so PyTorch does not warn about an abstract method.
        return None

    def __init_subclass__(cls, **kwargs):
        own_call = cls.__dict__.get('__call__')
        if own_call is not None:
            # Move the subclass's own ``__call__`` over to ``forward``.
            cls.forward = own_call
            del cls.__call__

    @property
    def device(self):
        """Device of the first parameter; ``RuntimeError`` if there are none."""
        try:
            return next(self.parameters()).device
        except StopIteration:
            raise RuntimeError(f"Unable to determine"
                               f" device of {self.__class__.__name__}") from None


class ShortcutProjection(Module):
    """1x1 strided convolution followed by batch normalisation.

    Used by the residual blocks as the skip path when the input and output
    differ in channel count or when the block's stride is not 1.
    """

    def __init__(self, in_channels: int, out_channels:int, stride:int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x:torch.Tensor):
        projected = self.conv(x)
        return self.bn(projected)
    
class ResidualBlock(Module):
    """Residual block of two 3x3 convolutions, each followed by batch norm.

    The block's input, via the skip path, is added to the output of the second
    batch norm, and the sum goes through a final ReLU. The skip path is the
    identity when the block keeps both resolution and channel count, and a
    ``ShortcutProjection`` otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        # First 3x3 convolution carries the stride.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = nn.ReLU()

        # Second 3x3 convolution always has stride 1.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        keeps_shape = stride == 1 and in_channels == out_channels
        if keeps_shape:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)

        self.act2 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        skip = self.shortcut(x)
        out = self.conv1(x)
        out = self.act1(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        return self.act2(out + skip)

class BottleneckResidualBlock(Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm.

    The first 1x1 convolution maps to ``bottleneck_channels``, the 3x3
    convolution carries the stride, and the last 1x1 convolution maps to
    ``out_channels``. The block's input, via the skip path, is added before the
    final ReLU; the skip path is the identity when shape is preserved and a
    ``ShortcutProjection`` otherwise.
    """

    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
        super().__init__()

        # 1x1 convolution into the bottleneck width.
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        self.act1 = nn.ReLU()

        # 3x3 convolution at the bottleneck width; this one is strided.
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels,
                               kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.act2 = nn.ReLU()

        # 1x1 convolution out to the block's output width.
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)

        keeps_shape = stride == 1 and in_channels == out_channels
        if keeps_shape:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)

        self.act3 = nn.ReLU()

    def forward(self, x:torch.Tensor):
        skip = self.shortcut(x)
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.conv3(out)
        out = self.bn3(out)
        return self.act3(out + skip)
    

class ResNetBase(Module):
    """ResNet feature extractor: a stem convolution plus stacked residual blocks.

    Args:
        n_blocks: number of residual blocks in each stage.
        n_channels: output channel count of each stage (same length as ``n_blocks``).
        bottelnecks: per-stage bottleneck widths; when given, every block is a
            ``BottleneckResidualBlock``, otherwise a ``ResidualBlock``.
        img_channels: channel count of the input images.
        first_kernel_size: kernel size of the stem convolution.

    ``forward`` returns, per sample, the mean of each output channel over all
    remaining positions.
    """

    def __init__(self, n_blocks:List[int], n_channels: List[int],
                 bottelnecks: Optional[List[int]] = None,
                 img_channels: int = 3, first_kernel_size: int = 7):
        super().__init__()
        assert len(n_blocks) == len(n_channels)
        assert bottelnecks is None or len(bottelnecks) == len(n_channels)

        # Stem: stride-2 convolution padded with half the kernel size (floor division).
        self.conv = nn.Conv2d(img_channels, n_channels[0],
                              kernel_size=first_kernel_size,
                              stride=2,
                              padding=first_kernel_size // 2)
        self.bn = nn.BatchNorm2d(n_channels[0])

        def build_block(stage, in_width, out_width, stride):
            # Plain [3, 3] block without bottlenecks, [1, 3, 1] block with them.
            if bottelnecks is None:
                return ResidualBlock(in_width, out_width, stride=stride)
            return BottleneckResidualBlock(in_width, bottelnecks[stage], out_width,
                                           stride=stride)

        layers = []
        in_width = n_channels[0]
        for stage, out_width in enumerate(n_channels):
            # Only the very first block of the network is strided.
            lead_stride = 1 if layers else 2
            layers.append(build_block(stage, in_width, out_width, lead_stride))
            in_width = out_width

            # Remaining blocks of the stage keep width and resolution.
            layers.extend(build_block(stage, out_width, out_width, 1)
                          for _ in range(n_blocks[stage] - 1))

        self.blocks = nn.Sequential(*layers)

    def forward(self, x:torch.Tensor):
        features = self.blocks(self.bn(self.conv(x)))
        # Collapse everything after (batch, channel) into one axis and average it.
        batch, channels = features.shape[0], features.shape[1]
        flat = features.view(batch, channels, -1)
        return flat.mean(dim=-1)

网络配置代码:

 class Config:
    """Network and training hyper-parameters kept as plain class attributes."""

    # --- Network layout (handed to ResNetBase by load_train_objs) ---
    # Number of residual blocks in each stage.
    n_blocks = [6, 6, 6]
    # Output channel count of each stage.
    n_channels = [16, 32, 64]
    # Bottleneck width of each stage.
    bottlenecks = [8, 16, 16]
    # Presumably the kernel size meant for ResNetBase's first convolution.
    first_kernel_size = 3

    # --- Training settings (not referenced by the training script shown here) ---
    batch_size = 256
    total_epoches = 500
    Lr = 0.1

数据集代码:

import json
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Any, Callable, Optional, Tuple
# 继承Dataset类
class CIFAR10_IMG(Dataset):
    """Image-folder dataset driven by a JSON annotation file.

    The annotation file is expected to hold two parallel lists: ``images``
    (file names inside the split's image folder) and ``categories`` (the label
    for each image).

    Args:
        root: dataset root directory containing ``annotations/`` and the
            ``train_cifar10/`` / ``test_cifar10/`` image folders.
        train: load the training split when true, the test split otherwise.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each label.
    """

    def __init__(self, root, train=True, transform = None, target_transform = None) -> None:
        super().__init__()
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        # Pick the annotation file and image folder of the requested split.
        if self.train:
            file_annotation = root + '/annotations/cifar10_train.json'
            img_folder = root + '/train_cifar10/'
        else:
            file_annotation = root + '/annotations/cifar10_test.json'
            img_folder = root + '/test_cifar10/'

        # Context manager so the file handle is closed as soon as it is parsed.
        with open(file_annotation, 'r', encoding='utf-8') as fp:
            data_dict = json.load(fp)

        # A length mismatch means the annotation file itself is broken.
        assert len(data_dict['images']) == len(data_dict['categories']), \
            "annotation file has a different number of images and categories"

        self.img_folder = img_folder
        self.filenames = list(data_dict['images'])
        self.labels = list(data_dict['categories'])

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` with the transforms applied."""
        img_name = self.img_folder + self.filenames[index]
        label = self.labels[index]
        # Read the image file from disk.
        img = plt.imread(img_name)
        if self.transform is not None:
            img = self.transform(img)
        # target_transform was accepted by __init__ but previously never applied.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.filenames)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CRE_MO

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值