解决问题的思路是,首先找到导致序列化失败的对象,然后将其覆盖为一个正常对象。
网上能找到很多相关解答,但是往往都是授人以鱼,没有授人以渔。这个问题的出现场景很多,但是归根结底,都是因为定义了一些不可被 pickle
序列化的对象,然后又将这些对象作为 multiprocessing
的参数传入了。所以要解决这个问题,我们必须知道是哪个对象不可序列化。在了解了 multiprocessing
的流程后,排查过程其实是很简单的。
先贴一下我的报错信息,我是在运行 DDP 的时候遇到了无法序列化的问题。具体过程是, DDP 在创建数据进程时调用了 multiprocessing
,而传入 multiprocessing
的参数不可序列化。
File "/home/pai/lib/python3.6/site-packages/mmcv/runner/epoch_based_runner.py", line 47, in train
for i, data_batch in enumerate(self.data_loader):
File "/home/pai/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 359, in __iter__
return self._get_iterator()
File "/home/pai/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 305, in _get_iterator
return _MultiProcessingDataLoaderIter(self)
File "/home/pai/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 918, in __init__
w.start()
File "/home/pai/lib/python3.6/multiprocessing/process.py", line 105, in start
self._popen = self._Popen(self)
File "/home/pai/lib/python3.6/multiprocessing/context.py", line 223, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "/home/pai/lib/python3.6/multiprocessing/context.py", line 284, in _Popen
return Popen(process_obj)
File "/home/pai/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 32, in __init__
super().__init__(process_obj)
File "/home/pai/lib/python3.6/multiprocessing/popen_fork.py", line 19, in __init__
self._launch(process_obj)
File "/home/pai/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47, in _launch
reduction.dump(process_obj, fp)
File "/home/pai/lib/python3.6/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects
看到这个报错,我的第一反应是数据进程间的同步问题,因为 DDP 使用多进程加载数据,而报错信息里又出现了 _thread.RLock
。但是后来发现这个 _thread.RLock
只是一个不能被 pickle
序列化的普通类型,报错本身和进程间同步是无关的。multiprocessing
使用序列化的目的主要是将 python 对象编码为二进制流,通过管道发送给子进程。
下面介绍一个通用的排查方法。首先定位到报错的前一行,也就是 ForkingPickler(file, protocol).dump(obj)
的前一行。根据报错信息可以看到,这行代码在我的 multiprocessing/popen_spawn_posix.py
第 60 行,行号可能因 python 版本而异。现在我们知道, obj
的其中一个属性(或者是属性的属性)一定是不可序列化的,而且这个属性的类型是 _thread.RLock
。为了找到这个属性,只需要对 obj
的属性逐一进行序列化测试。如果发现 obj
的其中一个属性不可序列化,但是这个属性的类型不是 _thread.RLock
,就说明这个属性的某个属性不可序列化,需要递归查找。这样就可以找到导致问题的对象了,将其删除或者替换都可以解决问题。
对于 PyTorch Dataloader 来说,序列化的 python 对象主要是数据集相关的,具体有哪些可以看 PyTorch 1.9.1 源码。注意报错就是在 918 行 w.start()
处发生的。进一步调试会发现,序列化失败的对象就是这里定义的 w
。但是 w
本身没有引入无法序列化的参数,所以最有可能导致序列化失败的对象就在 args
中。
# torch/utils/data/dataloader.py, line 904~918
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed, self._worker_init_fn, i, self._num_workers,
self._persistent_workers))
...
w.start()
一般来说, dataset
可能是用户自定义的类,所以出现问题的可能也较大,可以优先尝试用 pickle
对 dataset
进行序列化。如果失败,则说明导致序列化失败的对象是 dataset
的一个属性,可以逐一对 dataset
的属性进行序列化测试。当然其他对象也有可能出问题,保险起见可以使用上面的通用方法进行排查。
我定位到的问题是 dataset
中有一个 logger
对象。在 python 3.7 以前, logger
对象是不能序列化的,因为其中包含一些 lock
对象,而这些对象的类型是 _thread.RLock
。最简单的修改方法就是把这些 logger
全都删掉。或者也可以将 logger
设置为全局变量,从而避免序列化。下面是一个序列化测试的 toy example 。
import io
import pickle
import sys
import threading
print(sys.version_info)
lock = threading.RLock()
buffer = io.BytesIO()
pickle.dump(lock, buffer)
期望输出如下,可以看到报错信息和 multiprocessing
一致。
sys.version_info(major=3, minor=8, micro=9, releaselevel='final', serial=0)
Traceback (most recent call last):
File "debug.py", line 9, in <module>
pickle.dump(lock, buffer)
TypeError: cannot pickle '_thread.RLock' object
参考
TypeError: can‘t pickle _thread.RLock object解决