A supplement on torch.distributed

This note is mainly about a quirk of dist.isend (isend) and dist.irecv (irecv).

1: When you call irecv or isend, you must capture the return value; otherwise the communication never happens. I don't know the reason for certain; presumably the async request handle they return is destroyed as soon as it is discarded, before the transfer can complete. The blocking send and recv have no such requirement, as the short sketch below illustrates.
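For comparison, here is a minimal sketch of the blocking pattern (my own illustration; blocking_worker is a hypothetical name, not from the original code, and it is launched with the same mp.Process / MASTER_ADDR scaffolding as the examples below). dist.send and dist.recv only return once the transfer has completed, so there is no handle to keep:

import torch
import torch.distributed as dist


def blocking_worker(rank, world_size):
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    a = torch.randn((1, 2))
    if rank == 0:
        dist.recv(a, src=1)   # blocks until the tensor has actually arrived
        print('rank0 received:\n', a)
    else:
        dist.send(a, dst=0)   # blocks until the send has completed
        print('rank1 sent:\n', a)
    dist.destroy_process_group()

And here is the failing non-blocking version, with the return values discarded: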

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def ceshi(rank, world_size):
    a = torch.randn((1, 2))
    print('rank{}:\n'.format(rank), a)

    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    if rank == 0:
        dist.irecv(a, src=1)   # return value discarded: the receive never happens
        print('rank0 received:\n', a)
    else:
        dist.isend(a, dst=0)   # return value discarded: the send never goes through
        print('rank1 sent:\n', a)

if __name__ == '__main__':
    world_size = 2
    mp.set_start_method('spawn')

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    p0 = mp.Process(target=ceshi, args=(0, world_size))
    p1 = mp.Process(target=ceshi, args=(1, world_size))


    p0.start()
    p1.start()

    p0.join()
    p1.join()

Result: (output screenshot omitted: the transfer does not go through, and rank0 just prints its own original tensor)

But as soon as you capture the return value, everything works:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def ceshi(rank, world_size):
    a = torch.randn((1, 2))
    print('rank{}:\n'.format(rank), a)

    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    if rank == 0:
        req = dist.irecv(a, src=1)     # keeping the handle is the only change
        print('rank0 received:\n', a)  # strictly, call req.wait() first; see the sketch below
    else:
        req = dist.isend(a, dst=0)     # keeping the handle is the only change
        print('rank1 sent:\n', a)

if __name__ == '__main__':
    world_size = 2
    mp.set_start_method('spawn')

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    p0 = mp.Process(target=ceshi, args=(0, world_size))
    p1 = mp.Process(target=ceshi, args=(1, world_size))


    p0.start()
    p1.start()

    p0.join()
    p1.join()

Result: (output screenshot omitted: rank0 now prints the tensor sent by rank1)
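Even in the working version, nothing forces the transfer to finish before the tensor is printed. A safer pattern, sketched below under my own assumptions (safe_worker is a hypothetical name, and mp.spawn is used as the launcher instead of manual mp.Process calls), keeps the handle and calls req.wait() before touching the tensor:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def safe_worker(rank, world_size):
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    a = torch.randn((1, 2))
    if rank == 0:
        req = dist.irecv(a, src=1)
        req.wait()                     # block until the data has actually arrived
        print('rank0 received:\n', a)
    else:
        req = dist.isend(a, dst=0)
        req.wait()                     # block until the send buffer is safe to reuse
        print('rank1 sent:\n', a)
    dist.destroy_process_group()


if __name__ == '__main__':
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    mp.spawn(safe_worker, args=(2,), nprocs=2)

wait() blocks until the operation completes, and assigning the result to req keeps the handle alive in the meantime, so this pattern covers both the handle-lifetime issue above and the completion race at once.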

 
