pytorch 使用 multiprocess库 让 Dataloader 不再报错 AttributeError: Can‘t pickle local object

2023-4-27 更新
另外一种使用 torchdata 库的解决方法 https://blog.csdn.net/ONE_SIX_MIX/article/details/130405330


pytorch 的 dataloader 默认使用 python 自带的多进程库 multiprocessing ,它又使用 pickle 作为序列化库。
pickle 库只能储存一些简单类型。如果 dataset 中使用 lambda 函数对象,将会导致出现这样的错误 AttributeError: Can’t pickle local object

multiprocess 的 pip 安装方法

pip install -U multiprocess

第三方库的多进程库 multiprocess ,使用 dill 库,它可以序列化复杂的类型,例如 lambda 函数。

我们将用它替代掉 dataloader 中默认的多进程库。

!注意!使用 multiprocess 库替代后,可能会占用更多的内存。速度没差别。

以下为替代示例
以下变量 use_multiprocess 是开关
将其 use_multiprocess 设为 False ,即为使用默认的多进程库,执行以下代码会报错。
将 use_multiprocess 设为 True,即为使用第三方多进程库,以下代码可以正常执行

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

import multiprocessing
multiprocessing.set_start_method('spawn', force=True)

# Flag
use_multiprocess = True

if use_multiprocess:
    import multiprocess
    multiprocess.set_start_method('spawn', force=True)

    torch.utils.data.dataloader.python_multiprocessing = multiprocess
    new_multiprocess_ctx = multiprocess.get_context()

else:
    new_multiprocess_ctx = None


class SimpleDataset(Dataset):
    def __init__(self):
        # lambda function here
        self.func = lambda x: x+1

    def __len__(self):
        return 1000

    def __getitem__(self, i):
        return self.func(i)


if __name__ == '__main__':
    ds = SimpleDataset()

    dl = DataLoader(ds, 1, multiprocessing_context=new_multiprocess_ctx, num_workers=2)

    for x in dl:
        print(x)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值