Character sequences (strings) cannot be converted to tensors, so PyTorch's DataParallel cannot scatter them for data parallelism. Here is a workaround.

Solution: subclass DataParallel and override its scatter method so that list inputs are chunked across the target GPUs instead of being replicated onto every device.


from torch.nn.parallel._functions import Scatter
from torch.nn.parallel import DataParallel
import torch
# This code was copied from torch.nn.parallel and adapted for DataParallel to chunk lists instead of duplicating them
# (this is really all this code is here for)
def scatter(inputs, target_gpus, dim=0):
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            # Split the list into contiguous, roughly equal chunks, one per GPU,
            # instead of replicating the whole list onto every device.
            if len(target_gpus) == 0:
                return []
            num_gpus = len(target_gpus)
            num_samples = len(obj)
            samples_per_gpu = num_samples // num_gpus
            remaining_samples = num_samples % num_gpus

            distributed_samples = []
            start_idx = 0
            for gpu_idx in range(num_gpus):
                gpu_samples_count = samples_per_gpu + (1 if gpu_idx < remaining_samples else 0)
                gpu_samples = obj[start_idx:start_idx + gpu_samples_count]
                distributed_samples.append(gpu_samples)
                start_idx += gpu_samples_count

            return distributed_samples
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for _ in target_gpus]
    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """Scatter positional and keyword arguments, padding so that every replica
    receives an (inputs, kwargs) pair even when one of the two is empty."""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

class DataParallelV2(DataParallel):
    """DataParallel variant whose scatter chunks plain Python lists (e.g. lists
    of strings) across devices instead of trying to turn them into tensors."""

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
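
A minimal usage sketch (assumes at least two visible GPUs; StringModel is a hypothetical toy module whose forward takes a list of raw strings, where a real model would tokenize or embed them):

import torch
import torch.nn as nn

class StringModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, texts):
        # "texts" is the sub-list of raw strings scattered to this replica.
        lengths = torch.tensor([float(len(t)) for t in texts], device=self.scale.device)
        return lengths * self.scale

model = DataParallelV2(StringModel().cuda(), device_ids=[0, 1])
texts = ["hello", "world", "data", "parallel", "strings"]
out = model(texts)  # each replica gets its own chunk of texts; outputs are gathered on GPU 0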
