1.问题分析
torch.utils.data.DataLoader(image_datasets[x],
batch_size=batch_size,
shuffle=True,
num_workers=NUM_WORKERS,
pin_memory=True)
在Pytorch中通过以上代码获得一个对象用于加载数据集中的样本数据,其中num_workers参数用于指定执行数据加载的子进程个数,当该值为0时仅使用主进程进行数据集加载工作,而当该值大于0时,则会创建对应数量的子进程负责相关的数据集加载工作。
该方法使用multiprocessing包来实现创建新的子进程以及多进程同步,在Linux相关系统中通过使用fork函数能够正常的从对应代码(使用fork语句处)开始执行新的子进程,而在Windows相关系统则没有对应函数。
在Windows环境下,通过在训练文件Train.py中加入如下代码:
print('Running', __name__)
可获得终端输出信息如下:
Running __main__
……
Running __mp_main__
Running __mp_main__
Running __mp_main__
Running __mp_main__
Running __mp_main__
Running __mp_main__
Running __mp_main__
Running __mp_main__
也就是当主进程执行后,代码运行至Dataloader函数时,将创建新的子进程,并且这些子进程会从头开始执行相应的Train.py文件代码。 这种特性与Linux下使用fork创建子进程是完全不同的。
这种差异将导致在Windows环境下,相关训练代码中会存在一些代码被重复执行,由此导致一些异常错误。
例如通常在完成数据读入工作后,子进程应当立即结束(通过在Dataloader函数后输出__name__可以容易发现只有主进程会执行后续代码)。而若由于一些代码被重复执行导致子进程无法进行数据读入工作,而是停留在某行代码,将最终导致子进程无法正常结束并占用一定资源无法释放。笔者遇到的问题就是由于这些子进程重复执行网络实例生成的代码,最终导致内存不足,无法完成正常训练。
2.解决办法
将训练代码中不希望被重复执行的代码加以区分即可,最简单的实现方式即判断__name__的内容,具体代码实现如下:
if __name__ == '__main__':
# 这里放一些不希望被重复执行的代码
train()
如果你希望一些代码在Dataloader加载数据时被重复执行,也可以这样写:
if __name__ == '__main__':
train()
elif __name__ == '__mp_main__'
# 这里可以放一些希望被重复执行的代码
参考文章:
if name == ‘main’ 如何正确理解?
windows下pytorch Dataloader中多进行重复执行代码的问题解决(重复创建tensorboard日志)
// 全文完
因笔者能力有限,若文章内容存在错误或不恰当之处,欢迎留言、私信批评指正。
Email:YePeanut[at]foxmail.com