pytorch 使用 multiprocess库 让 Dataloader 不再报错 AttributeError: Can‘t pickle local object

2023-4-27 更新
另外一种使用 torchdata 库的解决方法 https://blog.csdn.net/ONE_SIX_MIX/article/details/130405330


pytorch 的 dataloader 默认使用 python 自带的多进程库 multiprocessing ,它又使用 pickle 作为序列化库。
pickle 库只能储存一些简单类型。如果 dataset 中使用 lambda 函数对象,将会导致出现这样的错误 AttributeError: Can’t pickle local object

multiprocess 的 pip 安装方法

pip install -U multiprocess

第三方库的多进程库 multiprocess ,使用 dill 库,它可以序列化复杂的类型,例如 lambda 函数。

我们将用它替代掉 dataloader 中默认的多进程库。

!注意!使用 multiprocess 库替代后,可能会占用更多的内存。速度没差别。

以下为替代示例
以下变量 use_multiprocess 是开关
将其 use_multiprocess 设为 False ,即为使用默认的多进程库,执行以下代码会报错。
将 use_multiprocess 设为 True,即为使用第三方多进程库,以下代码可以正常执行

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

import multiprocessing
multiprocessing.set_start_method('spawn', force=True)

# Flag
use_multiprocess = True

if use_multiprocess:
    import multiprocess
    multiprocess.set_start_method('spawn', force=True)

    torch.utils.data.dataloader.python_multiprocessing = multiprocess
    new_multiprocess_ctx = multiprocess.get_context()

else:
    new_multiprocess_ctx = None


class SimpleDataset(Dataset):
    def __init__(self):
        # lambda function here
        self.func = lambda x: x+1

    def __len__(self):
        return 1000

    def __getitem__(self, i):
        return self.func(i)


if __name__ == '__main__':
    ds = SimpleDataset()

    dl = DataLoader(ds, 1, multiprocessing_context=new_multiprocess_ctx, num_workers=2)

    for x in dl:
        print(x)

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: 当出现"AttributeError: can't pickle local object 'convert_frame.<locals>"错误时,通常是因为在保存模型时,包含了不能被序列化的对象。 这个错误可能是由于在模型的构建过程中使用Python的lambda函数或局部函数引起的。这些函数是本地定义的,无法被pickle模块序列化,因此导致了这个错误。 为了解决这个问题,我们可以将这些无法被序列化的函数定义转移到全局范围内。将它们定义为全局函数或方法会使它们能够正确地进行序列化和保存。 例如,如果出现错误的函数是"convert_frame",我们可以将其重构为全局函数,并将其定义放在模型保存代码之前,以确保模型能够正确保存。 另外,这个错误也可能与使用了不支持pickle模块的第三方相关。在这种情况下,我们需要将这些对象排除在保存操作之外,或者尝试使用其他序列化工具来替换pickle。 在总结,当出现"AttributeError: can't pickle local object 'convert_frame.<locals>"错误时,我们需要检查模型构建过程中是否使用了lambda函数或局部函数,并将其转移到全局范围内,或者排除不能被序列化的对象,或者尝试使用其他序列化工具。 ### 回答2: torch.save函数在保存模型时会通过pickle将模型对象序列化保存到文件中。然而,在执行pickle序列化操作时,会遇到一些限制,其中之一是对于局部对象的pickle支持较弱。 而报错信息中的"AttributeError: can't pickle local object 'convert_frame.<locals>"则提示在保存过程中遇到了一个名为'convert_frame'的局部对象,由于pickle的限制,该对象无法被正常序列化保存。 解决该问题的一种方法是将'convert_frame'对象改为全局对象,即将其定义在函数外部。这样可以避免pickle对局部对象的限制。 另一种方法是使用torch.jit进行模型的保存和加载,替代pickle。torch.jit可以对模型进行更高效的序列化和反序列化,且不受局部对象的限制。可以使用torch.jit.script装饰器将模型的定义封装为脚本,然后使用torch.jit.save函数保存模型,而加载模型时,使用torch.jit.load函数加载模型即可。 综上所述,遇到该错误,可以考虑将局部对象改为全局对象或使用torch.jit进行模型的保存和加载,以解决torch.save attributeerror: can't pickle local object 'convert_frame.<locals>'的问题。 ### 回答3: 报错"AttributeError: Can't pickle local object 'convert_frame.<locals>" 是由于torch.save()函数中调用了一个无法被pickle序列化的本地对象 'convert_frame.<locals>'。 出现这样的问题可能有几种原因。首先,可能是因为该本地对象内部有lambda表达式、嵌套函数或闭包函数,这些函数在pickle时无法被序列化。其次,可能是由于该本地对象包含了无法被序列化的其他对象或不可序列化的数据类型。 要解决这个问题,可以考虑以下几种方法: 1. 将不可序列化的本地对象改为全局函数:如果 'convert_frame.<locals>' 是一个嵌套在函数中的本地对象,可以尝试将其提取出来,转为全局对象。 2. 使用不同的序列化方法:torch.save()函数使用pickle进行序列化,但pickle对于某些特定对象类型可能不起作用。可以使用其他可序列化的方法,如HDF5或JSON等。 3. 改写无法序列化的本地对象:如果 'convert_frame.<locals>' 包含无法序列化的其他对象,可以尝试将其改写为可序列化的形式。例如,将其属性中的不可序列化对象转换为可序列化的形式。 4. 使用torch.jit保存模型:torch.jit模块提供了一种将PyTorch模型编译为脚本的方法,可以通过torch.jit.save()函数保存模型并加载。 5. 检查其他可能引起问题的部分:如果以上方法都无效,可能需要检查代码中其他可能引起问题的部分,例如其他函数调用或使用。 以上是一些可能的解决方法,具体的解决方案需要根据具体情况进行调试和修改。希望可以对您有所帮助!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值