PyTorch Notes - Gathering Results Across GPUs

Core steps:

  1. Serialize the per-GPU results
  2. Run an all_gather operation across GPUs
  3. Deserialize the gathered results (a one-call alternative is sketched right after this list)
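
On newer PyTorch versions (1.8+), torch.distributed.all_gather_object already wraps these three steps (pickle, pad, all_gather, unpickle) in a single call. A minimal sketch, assuming the default process group has been initialized:

import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
result_part = [1, 2, 3, 4]                              # per-GPU partial result
gathered = [None for _ in range(dist.get_world_size())]
# all_gather_object pickles each object, pads the byte buffers to a common
# size, all_gathers them, and unpickles on every rank.
dist.all_gather_object(gathered, result_part)
# gathered now holds every rank's result_part, indexed by rank.

The code later in this post does the same thing by hand, which also works on older PyTorch versions.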

Serialization

In [21]: import pickle
    ...: import torch
    ...: 
    ...: result_part = [1, 2, 3, 4]
    ...: part_tensor = torch.tensor(bytearray(pickle.dumps(result_part)), dtype=torch.uint8)
    ...: recover_result = pickle.loads(part_tensor.cpu().numpy().tobytes())

In [22]: result_part
Out[22]: [1, 2, 3, 4]

In [23]: recover_result
Out[23]: [1, 2, 3, 4]
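
Note that dist.all_gather requires every rank to contribute a tensor of the same shape, while the pickled byte strings generally differ in length across GPUs. The full code below therefore first gathers each part's byte length and pads every byte tensor up to the maximum length. A minimal sketch of that padding step, where max_len stands in for the largest gathered length:

import pickle
import torch

part_tensor = torch.tensor(bytearray(pickle.dumps([1, 2, 3, 4])), dtype=torch.uint8)
max_len = 128  # assumed: the largest byte length gathered from all ranks
padded = torch.zeros(max_len, dtype=torch.uint8)
padded[:part_tensor.numel()] = part_tensor
# Slicing back to the original length recovers the exact bytes after all_gather.
recovered = pickle.loads(padded[:part_tensor.numel()].numpy().tobytes())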

Code

import pickle
from typing import Tuple

import torch
import torch.distributed as dist


def get_dist_info() -> Tuple[int, int]:
    """
    get distributed information.
    """
    if torch.__version__ < '1.0':
        initialized = dist._initialized
    else:
        if dist.is_available():
            initialized = dist.is_initialized()
        else:
            initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
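
A small usage sketch: get_dist_info is handy for gating work (logging, saving files) to a single process; in a non-distributed run it simply reports rank 0 of world size 1.

rank, world_size = get_dist_info()
if rank == 0:
    # only the main process logs
    print(f'running on {world_size} process(es)')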

def collect_results_gpu(result_part, size=None):
    """Collect result parts from all GPUs onto rank 0 via all_gather.

    Args:
        result_part: result list produced on a single gpu.
        size: expected total length of the collected results; used to
            drop samples that the dataloader padded to even out shards.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    torch.cuda.synchronize()
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # interleave the parts: the distributed sampler assigns samples
        # to GPUs round-robin, so zipping the parts restores dataset order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        if size:
            ordered_results = ordered_results[:size]
        return ordered_results
    else:
        return None
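
A typical use is distributed evaluation: each GPU runs inference on its shard of the dataset and rank 0 assembles the final result. A usage sketch, where model, data_loader, and evaluate() are placeholders rather than code defined in this post:

# Hypothetical distributed evaluation loop.
results = []
for data in data_loader:               # data_loader uses a DistributedSampler
    with torch.no_grad():
        results.extend(model(data))    # per-GPU partial predictions
all_results = collect_results_gpu(results, size=len(data_loader.dataset))
if all_results is not None:            # only rank 0 receives the full list
    evaluate(all_results)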