import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
# Minimal torch.distributed demo: every rank creates a scalar tensor holding
# (rank + 1) on its own GPU, then the members of a small sub-group all_gather
# their tensors and print the result.
#
# Launch with torchrun (it exports RANK / LOCAL_RANK / WORLD_SIZE and the
# rendezvous address that init_process_group reads from the environment):
#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 --master_port 19699 test_dist.py

# Initialise the default (world) process group.
torch.distributed.init_process_group(backend="nccl")
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()

# The CUDA device index must be the node-local rank, not the global rank:
# on a multi-node job the global rank exceeds the number of GPUs per node.
# Fall back to the global rank when LOCAL_RANK is not exported (single node).
local_rank = int(os.environ.get("LOCAL_RANK", rank))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

# Each rank contributes a distinct value so the gathered output is easy to read.
data = rank + torch.tensor(1.0)
data = data.to(device)

# Wait for the host-to-device copy to finish before printing.
torch.cuda.synchronize()
print("gpu_id:",local_rank,"data:",data.cpu().detach())

# Create a sub-group made of the first two ranks (fewer if the job is smaller).
# new_group() is itself a collective: every rank in the world must call it,
# including ranks that will not be members of the new group.
group_ranks = [r for r in (0, 1) if r < world_size]
op = torch.distributed.new_group(ranks=group_ranks)

# One receive buffer per group member.
merge_data = [torch.zeros_like(data) for _ in group_ranks]

# Only group members take part in the collective; calling it from a non-member
# is a no-op that merely emits a "not in group" warning.
if rank in group_ranks:
    torch.distributed.all_gather(merge_data, data, group=op)
    print(merge_data)

# Release NCCL resources explicitly so the processes shut down cleanly.
torch.distributed.destroy_process_group()
04-29
5986

07-13
3008
