PicklingError: Can‘t pickle <function <lambda>...attribute lookup <lambda> on __main__ failed


dataset类的transform参数里使用了lambda函数

一、cell3 + cell4

# cell3
def target_transform(t):
    return torch.tensor([t]).float()

ds_train = datasets.ImageFolder("./dataset/cifar2/train/",
            transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./dataset/cifar2/test/",
            transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())

print(ds_train.class_to_idx)
# cell4
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)

cell3中datasets.ImageFolder使用了lambda函数,cell4中num_workers被设置为3,这两个因素共同作用导致报错。

二、torch.utils.data

torch.utils.data.DataLoader默认采用单进程(主进程)来加载数据,但可以通过num_workers设置同时使用几个子进程,num_workers=0表示只使用主进程。这里的workers由pytorch提供,其实现依赖于python的multiprocessing,其实现在windows下和unix下是不同的。

  • unix下默认采用fork(),子进程通过从父进程那里继承来的地址空间直接访问dataset和代码中其他带参数的函数
  • windows下默认采用spawn(),这时候会另起一个python解释器来执行主代码(main script),之后通过pickle序列化接收dataset, collate_fn以及其他参数来执行主代码内部需要由workers来执行的代码

采用spawn()的时候,worker_init_fn参数不能为unpicklable对象,例如lambda函数。

三、其他

  1. 由于spawn()的存在,linux下不会报错的代码在windows下可能会报错,所以在windows下使用多进程加载数据的时候要注意:
    (1)python脚本的主代码应该放在if name == ‘main’:内,这样它们就不会在worker子进程启动的时候再次运行,而DataLoader的构造是不需要被重复执行的,所以这部分代码也应该放在这里(比如cell4)
    (2)任何自定义的内容,包括参数collate_fn, worker_init_fn或dataset的具体代码(通常是函数的形式)则要放在__main__外面(比如cell3)。在pickle序列化的过程中对于函数传递的是引用而非二进制代码,所以要使worker子进程正常工作,这一步是必须的。
  2. torchvision.transforms.ToTensor,把读入的图片转为tensor
  3. torchvision.transforms.Compose,把对读入图片的各种处理组合起来

四、解决办法

方法一:依旧是在jupyter notebook环境中,把num_workers=3改为num_workers=0
方法二:由于cell3中lambda函数的存在,无论是notebook还是脚本中的__main__都不能实现子进程num_workers>0,所以把cell3改为:

ds_train = datasets.ImageFolder("./dataset/cifar2/train/",
            transform = transform_train)
ds_valid = datasets.ImageFolder("./dataset/cifar2/test/",
            transform = transform_train)

print(ds_train.class_to_idx)
# 也就是把target_transform = lambda ...删掉
# 可以打印一下ds_train[0]和type(ds_train[0][1])看一下

再把cell5改为:

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# 查看部分样本
from matplotlib import pyplot as plt 

plt.figure(figsize=(8,8)) 
for i in range(9):
    img,label = ds_train[i]
    img = img.permute(1,2,0)
    ax=plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d"%label)
    ax.set_xticks([])
    ax.set_yticks([]) 
plt.show()
# 也就是把label.item()改为label

参考

  • 25
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值