Fixing torch.cat(): input types can't be cast to the desired output type Byte

While training a model with U2Net recently, I ran into the following error:


```
RuntimeError: torch.cat(): input types can't be cast to the desired output type Byte
```

The full stack trace is:

```
Original Traceback (most recent call last):
  File "/datapython3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/datapython3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 268, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 165, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: torch.cat(): input types can't be cast to the desired output type Byte
```

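Cause: my first guess was a concurrency problem caused by several workers running at the same time. The traceback points to a more specific mechanism. The error is raised by `torch.stack(batch, 0, out=out)` in `collate_tensor_fn`, and `out` is only non-`None` when that code runs inside a DataLoader worker process. There, `out` is preallocated in shared memory with the dtype of the first sample in the batch (`elem`), here uint8 (Byte). If any other sample in the same batch has a dtype that cannot be cast to Byte, for example a float64 label, `torch.stack` fails with this error. In the main process `out` stays `None`, and `torch.stack` promotes the mixed dtypes instead.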
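The mechanism can be reproduced without a DataLoader. In this sketch the two tensors are hypothetical stand-ins for two labels in one batch (one uint8, one float64), and `out` mimics the worker branch by taking the first element's dtype (in plain memory here rather than shared memory). It assumes a recent PyTorch (2.x) in which `torch.stack` applies type promotion when no `out` tensor is given.

```python
import torch


def make_batch():
    # Hypothetical labels from two samples of the same batch:
    # one kept as uint8, the other converted to float64.
    a = torch.zeros(3, dtype=torch.uint8)
    b = torch.full((3,), 0.5, dtype=torch.float64)
    return [a, b]


def stack_like_main_process(batch):
    # out=None: torch.stack promotes uint8 + float64 to float64.
    return torch.stack(batch, 0)


def stack_like_worker(batch):
    # Like the worker branch: out takes the dtype of batch[0] (uint8 here).
    elem = batch[0]
    out = torch.empty((len(batch), *elem.shape), dtype=elem.dtype)
    try:
        torch.stack(batch, 0, out=out)
    except RuntimeError as err:
        return str(err)  # float64 cannot be cast to Byte
    return None


batch = make_batch()
print(stack_like_main_process(batch).dtype)  # torch.float64
print(stack_like_worker(batch))
```

The only difference between the two functions is the preallocated `out`, which is exactly what separates a worker process from the main process in `collate_tensor_fn`.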

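Fix 1:

When constructing the `DataLoader`, set `num_workers` to 0.

Drawback: training gets slower, because all data loading and collation then run in the main process.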

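```python
from torch.utils.data import DataLoader
from torchvision import transforms
# U2Net's data_loader.py (from the U2Net repository)
from data_loader import RescaleT, ToTensorLab, SalObjDataset

# train_img_name_list, train_label_name_list and batch_size_train are
# defined earlier in the training script, as in U2Net's u2net_train.py.
train_dataset = SalObjDataset(
    img_name_list=train_img_name_list,
    lbl_name_list=train_label_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        # RandomCrop(288),
        ToTensorLab(flag=0)]))
train_dataset.auto_collation = True  # note: DataLoader does not read this attribute
# num_workers=0: batches are collated in the main process, so out stays None
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)
```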
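Why `num_workers=0` avoids the error can be checked directly: calling `default_collate` in the main process, where `torch.utils.data.get_worker_info()` returns `None`, leaves `out` as `None`. The batch below is a hypothetical stand-in for two U2Net-style dict samples whose labels ended up with different dtypes; the sketch assumes PyTorch 1.11 or later, where `default_collate` is exported from `torch.utils.data`.

```python
import torch
from torch.utils.data import default_collate, get_worker_info

# Hypothetical dict samples whose labels ended up with different dtypes.
batch = [
    {"label": torch.zeros(1, 4, 4, dtype=torch.uint8)},
    {"label": torch.ones(1, 4, 4, dtype=torch.float64)},
]

# In the main process get_worker_info() returns None, so
# collate_tensor_fn leaves out=None and torch.stack promotes the dtypes.
assert get_worker_info() is None
collated = default_collate(batch)
print(collated["label"].shape, collated["label"].dtype)
# torch.Size([2, 1, 4, 4]) torch.float64
```

Note that the resulting batch is silently promoted to float64, so with `num_workers=0` the mixed dtypes are hidden rather than removed.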

Fix 2:

Patch PyTorch's own code.

The traceback above shows that the error is raised at /datapython3.11/site-packages/torch/utils/data/_utils/collate.py:165.
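Because the site-packages path and the line numbers depend on your environment and PyTorch version, it can help to ask Python which collate.py it actually imports. This is a sketch that assumes a PyTorch version defining `collate_tensor_fn` (as the traceback above shows); `torch.utils.data._utils.collate` is a private module, so use it for inspection only.

```python
import inspect

import torch.utils.data._utils.collate as collate_module

# Path of the collate.py that this Python environment actually imports.
print(collate_module.__file__)

# First line of collate_tensor_fn in that file (varies between versions).
_, start_line = inspect.getsourcelines(collate_module.collate_tensor_fn)
print(start_line)
```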

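Open collate.py, find the `collate_tensor_fn` function, and comment out the whole `if` block. `out` then stays `None` in worker processes too, so `torch.stack` promotes mixed dtypes just as it does with `num_workers=0`.

Drawback: you have to modify PyTorch's source (the change is lost whenever PyTorch is reinstalled or upgraded), and each batch costs one extra memory copy, because it is no longer stacked directly into a shared-memory tensor.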

```python
# In torch/utils/data/_utils/collate.py (torch and the typing names used
# below are already imported at the top of that file).
def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    elem = batch[0]
    out = None
    # if torch.utils.data.get_worker_info() is not None:
    #     # If we're in a background process, concatenate directly into a
    #     # shared memory tensor to avoid an extra copy
    #     numel = sum(x.numel() for x in batch)
    #     storage = elem._typed_storage()._new_shared(numel, device=elem.device)
    #     out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return torch.stack(batch, 0, out=out)
```
