最近在做一个强化学习的项目,运用多进程分布训练时遇到了段错误的问题,这里记录下解决的过程思路和方案。
由于智能体与环境交互的过程涉及到了第三方的程序以及大量的文件读写操作,使得整个实验过程非常慢,为了解决交互部分的速度瓶颈,采用Ape-X( Distributed Prioritized Experience Replay)的分布式训练思路,即多个actor负责与环境交互,得到的交互数据存储到公共replay memory中,一个leaner负责从memory中抽样训练更新网络。
由于Pytorch在多进程方面的封装较好,我采用torch.multiprocessing包来实现多进程,并通过其中的Queue队列来实现进程间通信,也就是actor将交互数据发送给learner。主要代码结构简化如下:
def actor(q):
# 创建环境
...
while True:
# 获取交互数据 batch 类型为Tensor
...
q.put(batch)
def learner(q)
# 创建memory
memory = Memory()
...
while True:
batch = q.get() # <--- *** 产生 SegFault的地方 ***
memory.push(batch)
update_model()
if __name__ == '__main__':
# 创建模型、优化器等
model = DQN()
model.share_memory()
...