torch分布式之组内通all_gather通信

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
#init process_group
torch.distributed.init_process_group(backend="nccl")
world_size = torch.distributed.get_world_size()
local_rank = torch.distributed.get_rank()

torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
data = local_rank+torch.tensor(1.0)
data =data.to(device)
#同步
torch.cuda.synchronize()
print("gpu_id:",local_rank,"data:",data.cpu().detach())

#新建分组/在新分组之间进行all_gather通信,并打印结果
op = torch.distributed.new_group(ranks=[0,1])
merge_data = [torch.zeros_like(data) for _ in range(2)]
torch.distributed.all_gather(merge_data, data,group=op)
if local_rank==0 or local_rank==1:
    print(merge_data)
    
#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 --master_port 19699 test_dist.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值