What torch.distributed.barrier does
During distributed training, PyTorch commonly has the main process pre-read and cache the data while the other processes read from that cache; the synchronization this requires between processes is implemented with torch.distributed.barrier(). A typical snippet looks like this:
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()      # block until every process in the group has reached this point
dist.all_reduce(t)  # element-wise sum of t across all processes; every rank gets the result
In essence, barrier() blocks each process until every process in the group has reached the call, and only then does the subsequent computation continue. A concrete illustration of this usage pattern is sketched below.
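The following is a minimal sketch of my own (not from the original post) of the "rank 0 prepares the data first, everyone else waits" idiom described above. It assumes the process group has already been initialized; the cache file name and the helper function are hypothetical and exist only for illustration.

import os
import torch.distributed as dist

CACHE_FILE = "dataset.cache"              # hypothetical cache path, for illustration only

def download_and_cache_data():
    # rank 0 does the real work; later callers find the cache and return immediately
    if not os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "w") as f:
            f.write("preprocessed data")

def prepare_dataset(rank):
    if rank != 0:
        dist.barrier()                    # non-main processes wait here for rank 0
    download_and_cache_data()
    if rank == 0:
        dist.barrier()                    # rank 0 arrives last and releases the others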
On dist.all_reduce and dist.all_gather
(The original post shows a figure taken from the PyTorch official documentation illustrating these two collectives; it is not reproduced here.)
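Since the figure is not reproduced, here is a short sketch (my own example, assuming a process group of world_size ranks has already been initialized) of the difference: all_reduce combines a tensor across ranks in place, so every rank ends up with the same reduced value, while all_gather gives every rank a list containing each rank's tensor.

import torch
import torch.distributed as dist

def demo_collectives(rank, world_size):
    # each rank contributes a tensor holding its own rank id
    x = torch.tensor([float(rank)])

    # all_reduce: in-place reduction; afterwards every rank holds the same sum
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    # x is now 0 + 1 + ... + (world_size - 1) on every rank

    # all_gather: every rank receives one tensor per rank
    y = torch.tensor([float(rank)])
    gathered = [torch.zeros_like(y) for _ in range(world_size)]
    dist.all_gather(gathered, y)
    # gathered == [tensor([0.]), tensor([1.]), ...] on every rank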
Source of the code above: Pytorch barrier
Reading the source code:
def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
    """
    Synchronizes all processes.

    This collective blocks processes until the whole group enters this function,
    if async_op is False, or if async work handle is called on wait().

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        device_ids ([int], optional): List of device/GPU ids.
                                      Valid only for NCCL backend.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("barrier")
        return

    opts = BarrierOptions()
    if device_ids is not None:
        if get_backend(group) != Backend.NCCL:
            raise RuntimeError(
                "Function argument device_ids not supported "
                "for the selected backend {}".format(get_backend(group))
            )
        if isinstance(device_ids, list):
            opts.device_ids = device_ids
        else:
            raise RuntimeError(
                "Invalid function argument: " "device_ids type should be List[int]"
            )

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.barrier(opts=opts)
    else:
        work = group.barrier(opts=opts)

    if async_op:
        return work
    else:
        work.wait()
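Reading the body: if the calling rank is not part of group, the function only warns and returns; device_ids is accepted only with the NCCL backend; the barrier is then issued on the default or given process group, and the call either returns the async work handle (async_op=True) or blocks in work.wait(). To see it in action, here is a minimal, self-contained sketch of my own (not from the original post), using the gloo backend and a hardcoded local rendezvous address, in which two processes meet at a barrier:

import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",   # hypothetical local rendezvous address
        rank=rank,
        world_size=world_size,
    )
    print(f"rank {rank}: before barrier")
    dist.barrier()                              # blocks until both ranks reach this line
    print(f"rank {rank}: after barrier")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)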