DistributedDataParallel多显卡训练模板

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
import torch
import os
import argparse
from torchvision.models import resnet18
from torch.optim import Adam
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from datetime import datetime

class MyDataset(Dataset):
    def __init__(self,data_root):
        # 初始化数据集
        pass
    
    def __len__(self):
        # 返回数据集总长度
        return 1000
    
    def __getitem__(self, index):
        image = torch.rand(3,64,64)
        label = torch.randint(0,100,(1,1))[0].float()
        return image,label

"""
    多进程环境下的梯度缩减(归约操作),以确保各进程间的梯度一致。
"""
def reduce_loss(tensor, rank, world_size):
    with torch.no_grad():
        dist.reduce(tensor, dst=0)
        if rank == 0:
            tensor /= world_size

"""
    执行训练任务的核心函数,每个进程都运行这个函数
"""
def main_worker(index, world_size, config):
    # 初始设置把当前进程分配给一个CUDA设备,后续之间通过 .cuda() 即可
    torch.cuda.set_device(index)

    # 初始化分布式环境的函数
    # backend="nccl" 指定了用于跨多个GPU进行通信的后端
    # init_method="env://" 分布式环境时通信的方式 "env://"意味着使用环境变量来传递用于初始化分布式训练环境的相关信息
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=index)
    now = datetime.now()

    # 模型保存路径
    save_path = os.path.join(config.save_path,now.strftime('train_%Y-%m-%d_%H-%M-%S'))
    os.makedirs(save_path,exist_ok=True)

    model = resnet18().cuda()
    # 加载预训练模型
    if config.pretrain != None:
        model.load_state_dict(torch.load(config.pretrain,map_location={'cuda:%d' % 0: 'cuda:%d' % dist.get_rank()}))
    # 处理预训练权重,使用 SyncBatchNorm 同步不同进程的BatchNorm层,然后用 DistributedDataParallel 包装模型实现多进程并行训练。
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # find_unused_parameters=True DDP会遍历模型所有的参数,以便发现在前向传播过程中未参与计算图的参数
    model = DistributedDataParallel(model, device_ids=[index],find_unused_parameters=True)
    model.train()

    optimizer = Adam(model.parameters(), lr=0.0001, betas=(0, 0.999), weight_decay=1e-4)
    mse_loss = torch.nn.MSELoss()

    dataset = MyDataset(config.data_root)
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=config.batch, shuffle=False, num_workers=8, pin_memory=True, drop_last=False, sampler=sampler)

    for epoch in range(config.epochs):
        # 训练进度条
        tqdm_bar = tqdm(dataloader,desc="训练_显卡{}".format(index))
        for data in tqdm_bar:
            image,label = data[0].cuda(),data[1].cuda()
            optimizer.zero_grad()
            output = model(image)
            loss = mse_loss(output,label)
            loss.backward()
            optimizer.step()
            tqdm_bar.set_postfix(loss=loss.item())
        # 保存模型参数
        if epoch % config.save_epoch == 0 and epoch != 0 and index == 0:
            torch.save(model.module.state_dict(),os.path.join(save_path,'epoch_{}.pth'.format(str(epoch))))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type= str,  default='29500', help="用于建立连接的主节点端口号")
    parser.add_argument('--size', type=str,  default='2', help="分布式训练中参与的总进程数量")
    parser.add_argument('--visible', type=str,  default="0,1",help="哪些GPU是可见的")
    parser.add_argument('--address', type=str,  default="localhost",help="设置分布式训练的主节点地址")
    parser.add_argument('--epochs', type=int,  default=100, help="训练轮次数")
    parser.add_argument('--batch', type=int,  default=32,help="batch size")
    parser.add_argument('--pretrain', type=str,  default=None, help="预训练模型路径")
    parser.add_argument('--save_path', type=str,  default="./checkpoint" ,help="模型保存目录")
    parser.add_argument('--save_epoch', type=int,  default=10, help="每迭代多少论保存模型")
    parser.add_argument('--data_root', type=int,  default=10,help="数据集路径")
    args = parser.parse_args()
    print(args)
    os.environ['MASTER_ADDR'] = args.address
    os.environ['MASTER_PORT'] = args.port
    os.environ['WORLD_SIZE'] = args.size
    os.environ['RANK'] = '0' # 主节点的序号
    os.environ["CUDA_VISIBLE_DEVICES"] = args.visible
    world_size = int(args.size)
    

    mp.spawn(main_worker, args=(world_size,args), nprocs=world_size, join=True)

整理一个通用模板,省的每次需要的时候再找,代码应该是可以直接运行的,就是损失老大了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值