With DDP, each GPU runs its own process: with two GPUs, the dataset is split into two shards and each process reads one of them. The code below uses DDP for distributed training correctly and can be used directly as a reference. Note: this code has only been used for single-machine multi-GPU training; multi-machine multi-GPU has not been tried due to limited resources. Run it from the terminal with:
python -m torch.distributed.launch --nproc_per_node 2 train.py
where 2 is the number of GPUs. (On recent PyTorch versions the equivalent launcher is torchrun --nproc_per_node 2 train.py.)
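A note on --local_rank: torch.distributed.launch starts one process per GPU and passes each process a --local_rank argument. The script below instead derives the rank from dist.get_rank(), which gives the same value on a single machine. For reference, a minimal sketch of the launcher's canonical pattern (the names here are illustrative, not taken from the script below):

import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # filled in by the launcher
cli_args = parser.parse_args()

dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(cli_args.local_rank)  # pin this process to its own GPU

The complete training script: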
import datetime
import os
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
import joint_transforms
from config import msra10k_path
from datasets import ImageFolder
from misc import AvgMeter, check_mkdir
from model import R3Net
from torch.backends import cudnn
import torch.distributed as dist                              # DDP-specific import
from torch.utils.data.distributed import DistributedSampler  # DDP-specific import
dist.init_process_group(backend='nccl', init_method='env://')  # DDP: join the process group
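# With init_method='env://', the rendezvous information comes from environment
# variables (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK) that the launcher sets,
# so nothing else has to be passed here; 'nccl' is the recommended backend for GPUs.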
batch_size = 12  # per-process (per-GPU) batch size
data_size = 25   # total batch size across all GPUs (not used below)
local_rank = torch.distributed.get_rank()  # on a single machine the global rank equals the local rank
torch.cuda.set_device(local_rank)
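# set_device makes cuda:<local_rank> the default device of this process, so the
# plain .cuda() calls below land on this process's own GPU rather than on GPU 0.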
# dist.init_process_group(backend='nccl', init_method='env://', world_size=2, rank=local_rank)
print(local_rank)  # each process prints its own rank; the order (0 then 1, or vice versa) is not guaranteed
cudnn.benchmark = True
torch.manual_seed(2018)
ckpt_path = './ckpt'
exp_name = 'R3Net/train_model'
args = {
'iter_num': 8000,
'train_batch_size': 10,
'last_iter': 0,
'lr': 1e-3,
'lr_decay': 0.9,
'weight_decay': 5e-4,
'momentum': 0.9,
'snapshot': ''
}
joint_transform = joint_transforms.Compose([
joint_transforms.RandomCrop(300),
joint_transforms.RandomHorizontallyFlip(),
joint_transforms.RandomRotate(10)
])
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
target_transform = transforms.ToTensor()
train_set = ImageFolder(msra10k_path, joint_transform, img_transform, target_transform)
train_sampler = DistributedSampler(train_set,
                                   num_replicas=2,   # number of processes; defaults to the world size if omitted
                                   rank=local_rank)  # which shard this process reads
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler)  # sampler replaces shuffle=True
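# batch_size here is per process: with 2 GPUs the effective global batch size is
# 2 * 12 = 24 samples per step, and each process sees roughly half of the dataset.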
#train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=12, shuffle=True)
criterion = nn.BCEWithLogitsLoss().cuda()
log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
def main():
    net = R3Net()
    device = torch.device('cuda:%d' % local_rank)
    net = net.to(device)  # move the model to this process's GPU before wrapping it
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[local_rank],   # must be a list
                                              output_device=local_rank)  # keep outputs on this process's GPU
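    # During backward(), DDP all-reduces the gradients across the processes, so after
    # optimizer.step() every replica holds identical weights; no manual syncing is needed.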
    # net.load_state_dict(torch.load('/home/yyb/pytorch_proj/R3Net/ckpt/R3Net/2020.7.3/1/12500.pth'))
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])
    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        # map_location keeps each process from loading the checkpoint back onto the GPU it was saved from
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'),
                                       map_location=device))
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth'),
                                             map_location=device))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']
    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(log_path, 'w').write(str(args) + '\n\n')
    train(net, optimizer)
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record, loss5_record, loss6_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_sim_record, loss5_sim_record = AvgMeter(), AvgMeter()
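        # Suggested addition (not part of the original script): DistributedSampler
        # shuffles with a seed derived from the epoch, so calling
        # train_sampler.set_epoch(epoch) before each pass over train_loader
        # gives a different shuffle every epoch.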
        for i, data in enumerate(train_loader):
            # polynomial learning-rate decay, with the bias group at twice the base lr
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda()  # lands on this process's GPU thanks to set_device; Variable wrappers are no longer needed
            labels = labels.cuda()
            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = net(inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss5 = criterion(outputs5, labels)
            loss6 = criterion(outputs6, labels)
            total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            total_loss.backward()  # DDP averages the gradients across processes here
            optimizer.step()
            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss5_record.update(loss5.item(), batch_size)
            loss6_record.update(loss6.item(), batch_size)
            curr_iter += 1
            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')
            # if curr_iter == 10500:
            #     torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
            #     torch.save(optimizer.state_dict(),
            #                os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
            if curr_iter % 400 == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d_epoch.pth' % (curr_iter // 1250)))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_epoch_optim.pth' % (curr_iter // 1250)))
            if curr_iter % args['iter_num'] == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
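            # Suggested addition (not part of the original script): every process
            # writes the same checkpoint files above; guarding the torch.save calls
            # with `if local_rank == 0:` avoids redundant writes.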
            if curr_iter == args['iter_num']:
                return

if __name__ == '__main__':
    main()
The official docstring of DDP:
class DistributedDataParallel(Module):
    r"""Implements distributed data parallelism that is based on
    ``torch.distributed`` package at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. The module is replicated on each machine and each device, and
    each such replica handles a portion of the input. During the backwards
    pass, gradients from each node are averaged.  # gradients from the different devices are averaged
    .. note:: If you use ``torch.save`` on one process to checkpoint the module,
        and ``torch.load`` on some other processes to recover it, make sure that
        ``map_location`` is configured properly for every process. Without
        ``map_location``, ``torch.load`` would recover the module to devices
        where the module was saved from.  # saving/loading a model across devices: use map_location
    .. note::
        Parameters are never broadcast between processes. The module performs
        an all-reduce step on gradients and assumes that they will be modified
        by the optimizer in all processes in the same way. Buffers
        (e.g. BatchNorm stats) are broadcast from the module in process of rank
        0, to all other replicas in the system in every iteration.  # parameters are not broadcast between processes
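To make the map_location note concrete, here is a minimal sketch of checkpointing on rank 0 and restoring on every rank; `net` and `local_rank` are the names from the script above, and the file name is illustrative:

import torch
import torch.distributed as dist

ckpt = 'snapshot.pth'  # illustrative file name
if dist.get_rank() == 0:
    torch.save(net.state_dict(), ckpt)  # tensors keep the device tags of rank 0 (cuda:0)
dist.barrier()  # wait until the file exists before the other ranks read it
# Without map_location every rank would load the tensors back onto cuda:0;
# mapping to the local GPU restores each replica onto its own device.
state = torch.load(ckpt, map_location='cuda:%d' % local_rank)
net.load_state_dict(state)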