2023-4-27 更新
另外一种使用 torchdata 库的解决方法 https://blog.csdn.net/ONE_SIX_MIX/article/details/130405330
pytorch 的 dataloader 默认使用 python 自带的多进程库 multiprocessing ,它又使用 pickle 作为序列化库。
pickle 库只能储存一些简单类型。如果 dataset 中使用 lambda 函数对象,将会导致出现这样的错误 AttributeError: Can’t pickle local object
multiprocess 的 pip 安装方法
pip install -U multiprocess
第三方库的多进程库 multiprocess ,使用 dill 库,它可以序列化复杂的类型,例如 lambda 函数。
我们将用它替代掉 dataloader 中默认的多进程库。
!注意!使用 multiprocess 库替代后,可能会占用更多的内存。速度没差别。
以下为替代示例
以下变量 use_multiprocess 是开关
将其 use_multiprocess 设为 False ,即为使用默认的多进程库,执行以下代码会报错。
将 use_multiprocess 设为 True,即为使用第三方多进程库,以下代码可以正常执行
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)
# Flag
use_multiprocess = True
if use_multiprocess:
import multiprocess
multiprocess.set_start_method('spawn', force=True)
torch.utils.data.dataloader.python_multiprocessing = multiprocess
new_multiprocess_ctx = multiprocess.get_context()
else:
new_multiprocess_ctx = None
class SimpleDataset(Dataset):
def __init__(self):
# lambda function here
self.func = lambda x: x+1
def __len__(self):
return 1000
def __getitem__(self, i):
return self.func(i)
if __name__ == '__main__':
ds = SimpleDataset()
dl = DataLoader(ds, 1, multiprocessing_context=new_multiprocess_ctx, num_workers=2)
for x in dl:
print(x)