核心步骤:
- 将结果序列化
- 跨卡执行 all_gather 操作
- 将结果反序列化
序列化
In [21]: import pickle
...: import torch
...:
...: result_part = [1, 2, 3, 4]
...: part_tensor = torch.tensor(bytearray(pickle.dumps(result_part)), dtype=torch.uint8)
...: recover_result = pickle.loads(part_tensor.cpu().numpy().tobytes())
In [22]: result_part
Out[22]: [1, 2, 3, 4]
In [23]: recover_result
Out[23]: [1, 2, 3, 4]
代码
import pickle
from typing import Tuple
import torch
import torch.distributed as dist
def get_dist_info() -> Tuple[int, int]:
    """Return the rank and world size of the current process.

    Returns:
        A ``(rank, world_size)`` tuple. When ``torch.distributed`` is
        unavailable or no process group has been initialized, ``(0, 1)``
        is returned so single-process code can use the same call.
    """
    # Detect the API by feature rather than by comparing
    # ``torch.__version__`` against a string: string comparison is
    # lexicographic and does not order version numbers reliably.
    if hasattr(dist, 'is_available'):
        initialized = dist.is_available() and dist.is_initialized()
    else:
        # Very old PyTorch builds only exposed a private flag.
        initialized = bool(getattr(dist, '_initialized', False))

    if not initialized:
        return 0, 1
    return dist.get_rank(), dist.get_world_size()
def collect_results_gpu(result_part, size=None):
    """Gather the results of every rank onto rank 0 via CUDA ``all_gather``.

    Each rank pickles ``result_part`` into a uint8 CUDA tensor. The byte
    lengths are exchanged first, every buffer is zero-padded to the longest
    length, and the padded buffers are then all-gathered. Rank 0 unpickles
    every part and interleaves them: item 0 of every rank, then item 1 of
    every rank, and so on.

    Args:
        result_part: result from single gpu. It must be picklable and
            iterable, since the parts are merged with ``zip``.
        size: the length of the expected collected result. When truthy,
            the merged list is truncated to this length.

    Returns:
        On rank 0, the merged list of results; on every other rank, ``None``.
    """
    # get_dist_info() returns exactly (rank, world_size); unpacking a third
    # value raised ValueError on every call.
    rank, world_size = get_dist_info()

    # Dump the local result into a byte tensor with pickle.
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')

    # Exchange the byte length of every rank's payload.
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    torch.cuda.synchronize()
    dist.all_gather(shape_list, shape_tensor)

    # all_gather is called with equally sized buffers, so pad the local
    # payload with zeros up to the longest length.
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]

    # Gather the padded payloads from all ranks.
    dist.all_gather(part_recv_list, part_send)

    if rank != 0:
        return None

    # NOTE: pickle.loads executes whatever the byte stream encodes; this is
    # only safe because the bytes come from peer ranks of the same job.
    part_list = [
        # Strip the padding using each rank's true length before unpickling.
        pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
        for recv, shape in zip(part_recv_list, shape_list)
    ]

    # Interleave the parts into one list. zip() stops at the shortest part.
    ordered_results = []
    for res in zip(*part_list):
        ordered_results.extend(res)

    # The dataloader may pad some samples; drop the surplus.
    if size:
        ordered_results = ordered_results[:size]
    return ordered_results