Personal distributed training template

A DDP template for personal use.
References:
PyTorch - DDP tutorial
PyTorch - torchrun elastic launch
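
Since the process group is initialized from environment variables, the script is meant to be launched with torchrun. A minimal single-node launch (assuming the file is saved as main.py on a node with 4 GPUs; both are assumptions, adjust to your setup): torchrun --standalone --nproc_per_node=4 main.py --cfg config.yaml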

import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from torch.distributed.elastic.multiprocessing.errors import record
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class CustomModel(nn.Module):
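    """A linear-attention block over 2D feature maps (cost linear in the number of spatial positions)."""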
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
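        # linear attention: fold the values into a per-head (d x d) context,
        # then read the context back out with the queries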
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

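# @record makes torchrun's elastic agent log the failing worker's traceback on a crash.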
@record
def main(custom_args):
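    # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
    # so the default env:// initialization needs no explicit arguments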
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")

    # map this process to a local GPU id
    gpu = rank % torch.cuda.device_count()
    world_size = dist.get_world_size()
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True
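    # let cuDNN benchmark and cache the fastest kernels for fixed input shapes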
    # placeholder data matching ToyModel's shapes; swap in your own Dataset here
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 5))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) if world_size > 1 else None
    dataloader = DataLoader(
        dataset,
        batch_size=20,
        shuffle=(sampler is None),
        sampler=sampler,
    )
    rank_iter = iter(dataloader) 
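    # NB: with a DistributedSampler, call sampler.set_epoch(epoch) at the start of
    # each epoch so shuffling differs across epochs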
    model = ToyModel().cuda()
    ddp_model = DDP(model, device_ids=[gpu])
    loss_fn = nn.MSELoss()
    optimizer = AdamW(ddp_model.parameters())
    optimizer.zero_grad()
    
    # train: a single illustrative step (wrap in `for batch in dataloader:` for real training)
    # inputs, labels = next(rank_iter)
    outputs = ddp_model(torch.randn(20, 10).to(gpu))
    labels = torch.randn(20, 5).to(gpu)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()
    
    # A step finished. To log the global mean loss across ranks:
    # dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # in-place reduction; returns None in sync mode
    # print(loss.item() / world_size)

    dist.destroy_process_group()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Argparser for configuring this codebase"
    )
    parser.add_argument("--cfg", type=str, default="config.yaml")
    args = parser.parse_args()
    main(args)
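
For multi-node elastic launches (the second reference above), the same script works unchanged; a sketch, where the rendezvous host and port are placeholders: torchrun --nnodes=2 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=HOST:29400 main.py --cfg config.yaml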