在使用PyTorch类的项目在windows环境下运行的时候尤其是CPU模式下运行的时候经常就会报各种各样的错误,尤其是跟DataLoader相关的,这里的报错就是因为DataLoader在windows下多线程加载数据集报错导致的,感觉torch对这个问题好像一直没有提上日程去解决它。
原始代码如下:

我常用的解决方法就是将nw值改为0即可,事实上这样的确也是起作用的。
但是今天我这么修改的时候又报错了,如下所示:

详情如下:
ValueError: persistent_workers option needs num_workers > 0
感觉这里的问题是由于persistent_workers参数导致的,我查询了一下DataLoader中的persistent_workers参数:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False,
drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2,
persistent_workers=False)
如果persistent_workers为True,数据加载器将不会在数据集运行完一个Epoch后关闭worker进程。这允许维护worker数据集实例保持激活。(默认值:False)
意思是运行完一个Epoch后并不会关闭worker进程,而是保持现有的worker进程继续进行下一个Epoch的数据加载。好处是Epoch之间不必重复关闭启动worker进程,加快训练速度。
也就是说作者这里设定persistent_workers为TRUE是为了提升训练速度,但是按照解决多进程数据加载报错的方法无意间触发了【num_workers > 0】的硬性要求,所以这里兼顾训练速度的办法就是将nw值改为1,如下:

当然了也可以选择将persistent_workers设置为False,如下:

重新执行,可以看到模型已经正常训练开始了:

记录备忘!

1616

被折叠的 条评论
为什么被折叠?



