This note is mainly about a pitfall with dist.isend and dist.irecv.
1: When you call irecv and isend, you must assign the returned request object to a variable; otherwise the communication never goes through. I don't know why for certain (my guess is that the returned Work handle is what keeps the outstanding request alive, so discarding it lets the request be destroyed before the transfer completes), and I'd still like to confirm the real reason. The blocking send and recv don't need this.
import torch.distributed as dist
import torch.multiprocessing as mp
import torch
import os

def ceshi(rank, world_size):
    # os.environ['CUDA_VISIBLE_DEVICES'] = rank
    a = torch.randn((1, 2))
    print('rank{}: \n'.format(rank), a)
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    if rank == 0:
        # req = dist.isend(a, dst=1)
        dist.irecv(a, src=1)  # return value discarded (this is the problem)
        # req = dist.irecv(a, src=1)
        print('rank0 received:\n', a)
    else:
        # req = dist.irecv(a, src=0)
        # req = dist.isend(a, dst=0)
        # req.wait()
        dist.isend(a, dst=0)  # return value discarded (this is the problem)
        # dist.irecv(a, src=0)
        print('rank1 sent:\n', a)

if __name__ == '__main__':
    world_size = 2
    mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    p0 = mp.Process(target=ceshi, args=(0, world_size))
    p1 = mp.Process(target=ceshi, args=(1, world_size))
    p0.start()
    p1.start()
    p0.join()
    p1.join()
Result: the communication never completes; rank0 still prints its own random tensor rather than the values sent by rank1.
But as soon as you keep the return value, everything works right away:
import torch.distributed as dist
import torch.multiprocessing as mp
import torch
import os

def ceshi(rank, world_size):
    # os.environ['CUDA_VISIBLE_DEVICES'] = rank
    a = torch.randn((1, 2))
    print('rank{}: \n'.format(rank), a)
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    if rank == 0:
        # req = dist.isend(a, dst=1)
        req = dist.irecv(a, src=1)  # keeping the returned request is the only change
        # req = dist.irecv(a, src=1)
        print('rank0 received:\n', a)
    else:
        # req = dist.irecv(a, src=0)
        # req = dist.isend(a, dst=0)
        # req.wait()
        req = dist.isend(a, dst=0)  # keeping the returned request is the only change
        # dist.irecv(a, src=0)
        print('rank1 sent:\n', a)

if __name__ == '__main__':
    world_size = 2
    mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    p0 = mp.Process(target=ceshi, args=(0, world_size))
    p1 = mp.Process(target=ceshi, args=(1, world_size))
    p0.start()
    p1.start()
    p0.join()
    p1.join()
Result: it works; rank0 prints the tensor that rank1 sent.
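For completeness: the usual pattern (and what the official PyTorch point-to-point tutorial does) is not just to keep the Work handle but to call wait() on it before touching the tensor, which guarantees the transfer has actually finished instead of relying on timing. Below is a minimal sketch of that pattern; the function name ceshi_wait, the explicit wait() calls, the mp.spawn launcher, and the destroy_process_group() cleanup are my additions, not part of the experiment above.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def ceshi_wait(rank, world_size):
    # hypothetical variant of ceshi() that waits for the request to finish
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    a = torch.randn((1, 2))
    if rank == 0:
        req = dist.irecv(a, src=1)  # keep the Work handle
        req.wait()                  # block until the receive has completed
        print('rank0 received:\n', a)
    else:
        req = dist.isend(a, dst=0)  # keep the Work handle
        req.wait()                  # block until the send has completed
        print('rank1 sent:\n', a)
    dist.destroy_process_group()    # tear the process group down cleanly

if __name__ == '__main__':
    world_size = 2
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # mp.spawn passes the rank as the first argument to ceshi_wait
    mp.spawn(ceshi_wait, args=(world_size,), nprocs=world_size)

With wait() in place it no longer matters how the two prints race: rank0 is guaranteed to see rank1's values, because the receive is complete before the tensor is read.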