PyTorch tutorial: torch.nn.parallel.DistributedDataParallel (DDP distributed training)

Roughly speaking, DDP training starts one process per GPU; with two GPUs, the dataset is split into two shards and each GPU reads one shard. The code below uses DDP correctly and can serve directly as a reference. Note: this code only covers single-machine multi-GPU training; multi-machine multi-GPU has not been tried here due to limited resources. Run it from the terminal with:
 
 python -m torch.distributed.launch --nproc_per_node 2 train.py   (the 2 is the number of GPUs you have)
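
Note: on recent PyTorch releases (1.10+), torch.distributed.launch is deprecated in favor of torchrun, which launches the same script as:

 torchrun --nproc_per_node 2 train.py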
 
import datetime
import os

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms

import joint_transforms
from config import msra10k_path
from datasets import ImageFolder
from misc import AvgMeter, check_mkdir
from model import R3Net
from torch.backends import cudnn

import torch.distributed as dist   # !!!!!!!!!!!!!!!!!!!!!!!!
from torch.utils.data.distributed import DistributedSampler  # !!!!!!!!!!!!!!!!!!!

dist.init_process_group(backend='nccl', init_method='env://')  # !!!!!!!!!!!!!!!!!!!!!
batch_size = 12  # batch size on each GPU (per process)      # !!!!!!!!!!!!!!!!!!!!!!!!!!
data_size = 25  # the overall batch size (note: this variable is never used below)   # !!!!!!!!!!!!!!!!!!!
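# With --nproc_per_node 2, each of the 2 processes draws batches of 12, so the
# effective global batch size per optimizer step is 24.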
local_rank = torch.distributed.get_rank()   #  !!!!!!!!!!!!!!!!!!!!
torch.cuda.set_device(local_rank)   #  !!!!!!!!!!!!!!!!!!!!!!!!
#dist.init_process_group(backend='nccl', init_method='env://', world_size=2, rank=local_rank)
print(local_rank) # Note!!!!!!!!!!!!! each process prints its own rank, so you will see both 0 and 1 (the order is not guaranteed)
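# Note: dist.get_rank() returns the *global* rank; on a single machine it equals
# the local rank, which is why this shortcut works here. For multi-node jobs you
# would read the local rank from the launcher instead (the --local_rank argument
# passed by torch.distributed.launch, or the LOCAL_RANK env var under torchrun).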
cudnn.benchmark = True
torch.manual_seed(2018)

ckpt_path = './ckpt'
exp_name = 'R3Net/train_model'

args = {
    'iter_num': 8000,
    'train_batch_size': 10,
    'last_iter': 0,
    'lr': 1e-3,
    'lr_decay': 0.9,
    'weight_decay': 5e-4,
    'momentum': 0.9,
    'snapshot': ''
}

joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(300),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10)
])
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
target_transform = transforms.ToTensor()

train_set = ImageFolder(msra10k_path, joint_transform, img_transform, target_transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_set,
                                                                num_replicas=2,
                                                                rank=local_rank)   # !!!!!!!!!!
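# (num_replicas and rank may also be omitted; they default to the process
#  group's world size and rank, which avoids hard-coding the number of GPUs.)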
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler)   # !!!!!!!!!!!
#train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=12, shuffle=True)
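# Note: DistributedSampler only reshuffles differently from epoch to epoch if
# train_sampler.set_epoch(epoch) is called before each pass over the data; this
# script loops by iteration count and never calls it, so every pass reuses the
# same shuffling order.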

criterion = nn.BCEWithLogitsLoss().cuda()
log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')


def main():
    net = R3Net()
    device = torch.device('cuda:%d' % local_rank)
    net = net.to(device)  # move the model onto this process's own GPU
    net = nn.parallel.DistributedDataParallel(net,
                                                device_ids=[local_rank, ],  # !!!!!!!!!!!! must be a list
                                                output_device=local_rank)  # !!!!!!!!!!!!!!!!!!!!!! must match this process's device
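    # Optional note (assumption about R3Net's internals): if the model contains
    # BatchNorm layers, DDP only broadcasts their running stats from rank 0 (see
    # the docstring notes below); to actually average the stats across GPUs, one
    # could call net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) before
    # the DDP wrapping above.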
    #net.load_state_dict(torch.load('/home/yyb/pytorch_proj/R3Net/ckpt/R3Net/2020.7.3/1/12500.pth'))

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])


    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(log_path, 'w').write(str(args) + '\n\n')
    train(net, optimizer)


def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record, loss5_record, loss6_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_sim_record, loss5_sim_record = AvgMeter(), AvgMeter()  ##

        for i, data in enumerate(train_loader):
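            # Polynomial ("poly") learning-rate decay:
            # lr = base_lr * (1 - curr_iter / iter_num) ** 0.9,
            # with the bias group (param_groups[0]) kept at twice the base rate.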
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                                ) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                            ) ** args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda()  # !!!!!!!!!!!!!!!!!!! set_device above makes .cuda() target this process's GPU
            labels = labels.cuda()  # !!!!!!!!!!!!!!!

            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = net(inputs) ##
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss5 = criterion(outputs5, labels)
            loss6 = criterion(outputs6, labels)

            total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss5_record.update(loss5.item(), batch_size)
            loss6_record.update(loss6.item(), batch_size)


            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f],[lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            # if curr_iter == 10500:
            #     torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
            #     torch.save(optimizer.state_dict(),
            #                os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
            if curr_iter % 400 == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d_epoch.pth' % (curr_iter // 1250)))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_epoch_optim.pth' % (curr_iter // 1250)))

            if curr_iter % args['iter_num'] == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
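            # Note: as written, every process saves to the same paths; a common
            # pattern is to guard the torch.save calls with `if local_rank == 0:`
            # so that only one process writes. Since net is DDP-wrapped, the keys
            # in net.state_dict() carry a 'module.' prefix; the raw weights are
            # available via net.module.state_dict().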
            if curr_iter == args['iter_num']:
                return


if __name__ == '__main__':
    main()

The official docstring for DDP:
class DistributedDataParallel(Module):
    r"""Implements distributed data parallelism that is based on
    ``torch.distributed`` package at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. The module is replicated on each machine and each device, and
    each such replica handles a portion of the input. During the backwards
    pass, gradients from each node are averaged.  # gradients from the different devices are averaged
    .. note:: If you use ``torch.save`` on one process to checkpoint the module,
        and ``torch.load`` on some other processes to recover it, make sure that
        ``map_location`` is configured properly for every process. Without
        ``map_location``, ``torch.load`` would recover the module to devices
        where the module was saved from.  # saving on one device and loading on another: use map_location (see the sketch after this excerpt)
    .. note::
        Parameters are never broadcast between processes. The module performs
        an all-reduce step on gradients and assumes that they will be modified
        by the optimizer in all processes in the same way. Buffers
        (e.g. BatchNorm stats) are broadcast from the module in process of rank
        0, to all other replicas in the system in every iteration. # parameters themselves are never broadcast between processes
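A minimal sketch of the map_location pattern from the first note above (the checkpoint filename is illustrative, and it assumes the file was saved from the DDP-wrapped model as in the script above):

# Each process remaps tensors that were saved from 'cuda:0' onto its own GPU.
map_location = {'cuda:0': 'cuda:%d' % local_rank}
state_dict = torch.load('8000.pth', map_location=map_location)
net.load_state_dict(state_dict)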
Reference: 1. On the usage of DistributedSampler(dataset) when training with DDP mode (torch.nn.parallel.DistributedDataParallel) in PyTorch.

Original article by CSDN blogger 贾小树, licensed under CC 4.0 BY-SA; include the original link and this notice when reposting.
Original link: https://blog.csdn.net/j879159541/article/details/107173029
