单卡训练
# data load
train_data = MyDataset()
train_data_loader = DataLoader(dataset=train_data, batch_size=bs, shuffle=True, drop_last=True, num_workers=8, pin_memory=True)
# init network
net = MyNet().cuda()
# fix: 'paramrters' -> 'parameters'
optimizer = optim.AdamW(net.parameters(), lr=1e-6, weight_decay=1e-3)
# load weight
# fix: torch.load takes the path directly; '*model_path' would unpack the
# string into single characters and raise a TypeError
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['state_dict']['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
# save model: keep the same nesting that the load code above expects
states = {
    'state_dict': {'net': net.state_dict()},
    'optimizer': optimizer.state_dict()
}
torch.save(states, model_save_path)  # fix: no '*' splat on the path
单机多卡训练(DP)
# data load
train_data = MyDataset()
train_data_loader = DataLoader(dataset=train_data, batch_size=bs, shuffle=True, drop_last=True, num_workers=8, pin_memory=True)
# init network
net = MyNet().cuda()
# fix: 'paramrters' -> 'parameters'
optimizer = optim.AdamW(net.parameters(), lr=1e-6, weight_decay=1e-3)
# 如果模型内有bn操作,需要转换
def convert_model(module):
    """Recursively replace every ``torch.nn.BatchNorm*d`` inside *module*
    with the corresponding ``SynchronizedBatchNorm*d`` so batch statistics
    are synchronized across GPUs under DataParallel.

    Args:
        module: the input module to convert to a SyncBN model.

    Returns:
        The converted module. If *module* is wrapped in
        ``torch.nn.DataParallel``, the result is re-wrapped in
        ``DataParallelWithCallback`` on the same device ids.

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    # DataParallel wrapper: unwrap, convert the inner module, then re-wrap
    # with the callback-aware DataParallel that SyncBN requires.
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            # Re-create the layer with identical hyper-parameters, then copy
            # the running statistics and (if present) the affine parameters.
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    # Convert children recursively and attach them to the (possibly new) module.
    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
# Replace plain BatchNorm layers with SyncBN variants (needed only when the
# model contains BN), then wrap the model for single-machine multi-GPU DP.
net = convert_model(net)
net = nn.DataParallel(net).cuda()
多机多卡训练
# data load
# nodes表示有多少节点,假设现在有两个节点,每个节点八张GPU。则nnodes=2, nproc_per_node=8, 每个node都有一个node_rank, 节点0: node_rank=0, 节点1: node_rank=1, 总共16张卡, rank的范围是0-15, 序号从0开始一直到整个分布式中最后一个GPU的数, local_rank是每个节点内对GPU的编号,范围是0-7。
# 如果是单机多卡的机器,WORLD_SIZE代表着使用进程数量(一个进程对应一块GPU),这里RANK和LOCAL_RANK这里的数值是一样的,代表着WORLD_SIZE中的第几个进程(GPU)。
# 如果是多机多卡的机器,WORLD_SIZE代表着所有机器中总进程数(一个进程对应一块GPU),RANK代表着是在WORLD_SIZE中的哪一个进程,LOCAL_RANK代表着当前机器上的第几个进程(GPU)。
# Resolve the distributed topology from the environment (set by the launcher).
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
else:
    # single-process fallback
    rank = 0
    # fix: default to 0 when LOCAL_RANK is absent instead of raising KeyError
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = 1
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
# Offset the seed by rank so each process draws different random numbers.
random.seed(opt.seed + rank)
np.random.seed(opt.seed + rank)
torch.manual_seed(opt.seed + rank)       # also seed the CPU RNG, not only CUDA
torch.cuda.manual_seed(opt.seed + rank)
train_data = MyDataset()
# In distributed training the main process pre-reads/caches the data while the
# other processes read from the cache; torch.distributed.barrier() keeps the
# processes in sync at this point.
torch.distributed.barrier()
train_sampler = None
if world_size > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
# fix: shuffle explicitly when no sampler is used (DataLoader defaults to
# shuffle=False, so the single-process path silently trained on ordered data)
training_data_loader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, sampler=train_sampler, shuffle=(train_sampler is None), drop_last=True)
# init network
net = MyNet().cuda()
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net).cuda()
# fix: the original inline comment swallowed the closing paren -> SyntaxError;
# pass find_unused_parameters=True here if parts of the model are unused.
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[local_rank], broadcast_buffers=False)
# fix: 'paramrters' -> 'parameters'
optimizer = optim.AdamW(net.parameters(), lr=1e-6, weight_decay=1e-3)
# load weight
# fix: no '*' splat on the path; map_location avoids loading onto GPU 0 from
# every rank before the parameters are moved to the right device
checkpoint = torch.load(model_path, map_location='cpu')['state_dict']['net']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    # fix: only strip a genuine 'module.' prefix; the old 'module' in k test
    # would truncate any key merely containing the substring "module"
    if k.startswith('module.'):
        k = k[7:]  # remove 'module.'
    new_state_dict[k] = v
# fix: the prefix was stripped, so load into the wrapped model (net.module),
# not the DDP wrapper whose keys all carry 'module.'
net.module.load_state_dict(new_state_dict)
# train and save model
for epoch in range(10):
    # fix: train_sampler is None in the single-process path; guard before
    # calling set_epoch (which reseeds the sampler so each epoch sees a
    # different data order)
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)
    # fix: 'enumrate' -> 'enumerate'
    for iteration, data in enumerate(training_data_loader):
        '''
        your code
        '''
    # fix: rank == 0 already implies the first GPU of node 0, so the extra
    # local_rank check was redundant; only one process writes the checkpoint
    if rank == 0:
        states = {
            'state_dict': {'net': net.state_dict()},
            # fix: store the optimizer state directly, matching the format the
            # single-GPU load code reads via checkpoint['optimizer']
            'optimizer': optimizer.state_dict()
        }
        torch.save(states, model_out_path)
Tips:
多机多卡DDP训练命令参考:
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr='localhost' --master_port=12345 --use_env train.py
或者自动获取
python -m torch.distributed.launch --nproc_per_node=${KUBERNETES_CONTAINER_RESOURCE_GPU} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \
--nnodes=${WORLD_SIZE} \
--node_rank=${RANK} \
--use_env train.py