gpu和cpu的多进程调用不一样,下面的代码有两行我加了注释,根据需求注释掉其中一行就行,若gpu则需要调用torch,否则用另一个就行了。看懂下面这个就可以自己写自己的多进程了,很简单,学会的话收藏点个赞吧,我会很开心的。
大体模版就是这样写,根据自己的任务具体写,这里我是读文件内容,再传给data_write函数(写在for循环里那),多进程执行这个函数,同学们可以编写自己的函数。另外我在data_write外面套了一个LogExceptions函数,这个函数的作用是可以抓取多进程执行过程中的报错信息(因为多进程过程中好多报错并不会打印出来),详细可以看我另一篇文章。
if __name__ == '__main__':
starttime_a = datetime.datetime.now()
torch.multiprocessing.log_to_stderr()
ctx = torch.multiprocessing.get_context("spawn")
multi_num = os.cpu_count() // 2
if multi_num > 3:
multi_num = 3
# pool = multiprocessing.Pool(processes=multi_num) #cpu版本
pool = ctx.Pool(multi_num) #gpu版本
rows=[]
print('reading...',path)
with open(path,'r',encoding='utf-8') as csvfile:
reader = csv.reader(csvfile)
rows = [row for row in reader]
print('len',len(rows))
step = 2000
cnt = int(len(rows) / step) + 1
print('总进程个数',cnt)
for i in range(cnt):
multi_rows = rows[i*step:(i+1)*step]
# data_write(multi_rows,model,preprocess,device)
pool.apply_async(LogExceptions(data_write),(multi_rows,model,preprocess,device,))
pool.close()
pool.join()
endtime_a = datetime.datetime.now()
print('down')
print("总耗时:", endtime_a - starttime_a)