PyTorch: Distributed ResNet Training with DistributedDataParallel

Preface

There are plenty of DistributedDataParallel tutorials online, but most of them are isolated code fragments that are hard to assemble into a working whole, and the official PyTorch example omits the data-loading part. Below is the usage I arrived at after putting the pieces together myself; if anything is wrong, feel free to point it out.

Official PyTorch demo:

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

1. Define the ResNet network structure

resnet.py

from torch import nn
from torch.nn import functional as F


class resblock(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(resblock, self).__init__()
        self.conv_1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn_1 = nn.BatchNorm2d(ch_out)
        self.conv_2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn_2 = nn.BatchNorm2d(ch_out)
        self.ch_in, self.ch_out, self.stride = ch_in, ch_out, stride
        self.ch_trans = nn.Sequential()
        if ch_in != ch_out:
            self.ch_trans = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                                          nn.BatchNorm2d(self.ch_out))
        # ch_trans adjusts the channel count (and spatial size) of the input so that
        # x_pro and x_ch have exactly the same shape for the shortcut addition

    def forward(self, x):
        x_pro = F.relu(self.bn_1(self.conv_1(x)))
        x_pro = self.bn_2(self.conv_2(x_pro))

        # shortcut connection: project x to match x_pro, then add
        x_ch = self.ch_trans(x)
        out = x_pro + x_ch
        return out


class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64))
        self.block1 = resblock(64, 128, 2)  # halves H and W: 32 -> 16
        self.block2 = resblock(128, 256, 2)  # halves H and W again: 16 -> 8
        self.block3 = resblock(256, 512, 1)
        self.block4 = resblock(512, 512, 1)
        self.outlayer = nn.Linear(512, 10)  # adaptive average pooling reduces the 512x8x8 feature map to a 512-dim vector

    def forward(self, x):
        x = F.relu(self.conv_1(x))
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = F.adaptive_avg_pool2d(x, [1, 1])
        x = x.reshape(x.size(0), -1)
        result = self.outlayer(x)
        return result


if __name__ == '__main__':
    net = ResNet()
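
A quick sanity check (not part of the original post) is to push one CIFAR-10-sized batch through the network and confirm the output shape:

import torch
from resnet import ResNet

# CIFAR-10 images are 3x32x32; the classifier should emit 10 logits per image
net = ResNet()
dummy = torch.randn(4, 3, 32, 32)  # a fake batch of 4 images
out = net(dummy)
print(out.shape)  # expected: torch.Size([4, 10])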

2. Define the helper utilities needed by DistributedDataParallel

distribute_tools.py

Here:
setup must be run when each worker process starts.
cleanup runs after a worker process has finished.
run_distribute is the entry point that spawns the worker processes.

import os
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank, world_size):
    # Address/port of the rank-0 process; all workers rendezvous here
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # "gloo" also works on GPU, but "nccl" is the usual choice for multi-GPU training
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def run_distribute(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
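
To see these helpers in isolation, a minimal throwaway worker (hello_worker is hypothetical, not part of the original scripts) that only prints its rank can be spawned like this:

import torch.distributed as dist
from distribute_tools import setup, cleanup, run_distribute


def hello_worker(rank, world_size):
    # mp.spawn passes the rank as the first argument; world_size comes from args=(world_size,)
    setup(rank, world_size)
    print(f"hello from rank {rank} of {world_size}, backend: {dist.get_backend()}")
    cleanup()


if __name__ == '__main__':
    run_distribute(hello_worker, 2)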

3. Training script

train_distribute.py

"""
使用DistributedDataParallel
进行ResNet分布式训练
"""
import pandas as pd
import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from resnet import ResNet
from distribute_tools import setup, cleanup, run_distribute
from torch.nn.parallel import DistributedDataParallel



def main_worker(rank, world_size):
    """
    这里定义每个线程需要执行的任务
    :param rank: 线程号,mp.spawn会自动传进来
    :param world_size: 节点数*每个节点控制的GPU数量,这个有两个节点,每个节点控制一张卡, 所以 2*1=2
    :return:
    """
    print(f"Running basic DistributedDataParallel example on rank {rank}.")
    setup(rank, world_size)

    batchsz = 128
    normalize_op = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    to_tensor_op = transforms.ToTensor()

    cifar_train = datasets.CIFAR10(r'./data',
                                   train=True,
                                   transform=transforms.Compose([
                                       to_tensor_op,
                                       normalize_op
                                       ]),
                                   download=True)
    # The DistributedSampler shards the dataset across processes; with two GPUs each one sees half of the data
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        cifar_train,
        num_replicas=world_size,
        rank=rank
    )
    # Pass the sampler to the DataLoader; with 2 GPUs each process then runs half as many steps per epoch
    cifar_train = DataLoader(dataset=cifar_train,
                             batch_size=batchsz,
                             shuffle=False,
                             num_workers=0,
                             pin_memory=True,
                             sampler=train_sampler
                             )


    cifar_test = datasets.CIFAR10(r'./data',
                                  train=False,
                                  transform=transforms.Compose([
                                      to_tensor_op,
                                      normalize_op]),
                                  download=True)
    cifar_test = DataLoader(cifar_test,
                            batch_size=batchsz,
                            shuffle=True)


    model = ResNet().to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])

    criteon = nn.CrossEntropyLoss().to(rank)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Collect training statistics
    res_data = {"epoch": [], "loss": [], "acc": []}
    for epoch in range(100):
        train_sampler.set_epoch(epoch)
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(rank), label.to(rank)
            logits = model(x)
            loss = criteon(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batchidx % 50 == 0:
                print(f"rank: {rank}, epoch: {epoch}, step: {batchidx}, loss: {loss.item()}")
        # Run validation on rank 0 only
        if rank == 0:
            model.eval()
            with torch.no_grad():
                total_correct = 0
                total_num = 0
                for x, label in cifar_test:
                    x, label = x.to(rank), label.to(rank)
                    logits = model(x)
                    pred = logits.argmax(dim=1)
                    correct = torch.eq(pred, label).float().sum().item()
                    total_correct += correct
                    total_num += x.size(0)
                acc = total_correct / total_num
                print(f"epoch: {epoch}, acc: {acc}")
            res_data["epoch"].append(epoch)
            res_data["loss"].append(loss.item())
            res_data["acc"].append(acc)
            df = pd.DataFrame(res_data)
            df.to_csv("pytorch_res.csv", index=False)
    cleanup()


if __name__ == '__main__':
    import time

    start_time = time.time()
    run_distribute(main_worker, 2)
    finish_time = time.time()
    print("total time cost: {} s".format(finish_time-start_time))
