PyTorch DataParallel: the problem of scattering custom data objects


Error message

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

How to check

Add a test line to the model's forward() function to check where the data lives, for example:

print("blocks:%s, batch:%s" % (self.encoder_blocks[0].KPConv.weights.device, batch.features.device))

Output:

blocks:cuda:0, batch:cuda:0
blocks:cuda:1, batch:cuda:0

Here blocks is a model parameter and batch is a data object of type S3DISCustomBatch. The model has been replicated onto every GPU, but the data exists only on cuda:0, so the replica running on cuda:1 raises the error.
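For reference, a minimal sketch of where such a check can sit inside a module's forward(); ToyNet and its attributes are hypothetical stand-ins, not the actual KPConv encoder:

import torch
import torch.nn as nn

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, batch):
        # Under DataParallel this runs once per replica, so a mismatch
        # (weights on cuda:1, data stuck on cuda:0) shows up immediately.
        print("weights:%s, batch:%s" % (self.linear.weight.device,
                                        batch.features.device))
        return self.linear(batch.features)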

Root cause

Reading the DataParallel source code shows why: the data I pass to the model is a custom object, and DataParallel cannot split object-typed data.
It can only automatically scatter data of type Tensor, tuple, list, and dict, as the code below shows.

# torch.nn.parallel.scatter_gather

def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        # print("do scatter deep: ", target_gpus, obj)
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
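The last line of scatter_map is the key: any object it does not recognize is simply duplicated by reference onto every GPU, so each replica ends up holding the same batch whose tensors still live on cuda:0. Below is a minimal sketch of that fallback, assuming a machine with two GPUs (Holder is a hypothetical stand-in for the custom batch, not part of the original code):

import torch
from torch.nn.parallel.scatter_gather import scatter

class Holder:
    """Hypothetical non-tensor container, standing in for S3DISCustomBatch."""
    def __init__(self, features):
        self.features = features

# A plain tensor is split across the two GPUs: devices cuda:0 and cuda:1.
chunks = scatter(torch.randn(8, 3, device="cuda:0"), target_gpus=[0, 1])
print([c.device for c in chunks])

# ...but an unrecognized object hits the final fallback and is only
# duplicated by reference, so every replica still points at cuda:0 data.
copies = scatter(Holder(torch.randn(8, 3, device="cuda:0")), target_gpus=[0, 1])
print([c.features.device for c in copies])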

Solution

Add a check for our data object type to the scatter_map function above, and implement the splitting code for that object:

        # Scatter the object's attribute dict like an ordinary dict of tensors,
        # then rebuild one instance of the object per GPU from each piece
        if str(type(obj)) == "<class 'datasets.S3DIS.S3DISCustomBatch'>":
            return [type(obj)(i) for i in scatter_map(obj.__dict__)]

At the same time, modify our object's constructor with the code below, so that an instance can be rebuilt directly from a scattered dict:

        # When handed a dict (produced by the scatter_map branch above),
        # adopt it directly as this instance's attributes and skip the
        # normal construction path
        if isinstance(input_list, dict):
            self.__dict__ = input_list
            return
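Putting the two changes together, here is a minimal sketch of what such a batch class looks like after the fix. CustomBatch is a simplified stand-in for S3DISCustomBatch with made-up features/labels attributes; the real class builds many more fields from the dataloader output.

import torch

class CustomBatch:
    """Simplified stand-in for S3DISCustomBatch."""

    def __init__(self, input_list):
        # Reconstruction path used by the scatter_map branch: input_list is
        # an attribute dict whose tensors already sit on one target GPU.
        if isinstance(input_list, dict):
            self.__dict__ = input_list
            return
        # Normal path: build the batch from the dataloader's raw output.
        self.features = torch.as_tensor(input_list[0])
        self.labels = torch.as_tensor(input_list[1])

With this in place, scatter_map splits batch.__dict__ into per-GPU dicts of tensor chunks and rebuilds one batch object per replica, so every GPU's forward() receives data on its own device.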