问题场景
今天在看KiU-Net的源码时,数据预处理非常耗时(共处理111套数据,耗时约1个小时)。
由于数据预处理是一个可以并行的操作,因此想到可以通过多进程来加速这个预处理的过程。
源代码
"""
Build the training set for the network.

Takes roughly 40 minutes and produces about 3 GB of training data.
"""
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
import shutil
from time import time
import numpy as np
from tqdm import tqdm
import SimpleITK as sitk
import scipy.ndimage as ndimage
import parameter as para

# Start from a clean output directory.
if os.path.exists(para.training_set_path):
    shutil.rmtree(para.training_set_path)

new_ct_path = os.path.join(para.training_set_path, 'ct')
new_seg_dir = os.path.join(para.training_set_path, 'seg')

# makedirs also creates the missing parent (os.mkdir would raise if
# para.training_set_path's parent did not exist).
os.makedirs(new_ct_path)
os.makedirs(new_seg_dir)

start = time()
for file in tqdm(os.listdir(para.train_ct_path)):

    # Read the CT volume and its gold-standard segmentation into memory.
    ct = sitk.ReadImage(os.path.join(para.train_ct_path, file), sitk.sitkInt16)
    ct_array = sitk.GetArrayFromImage(ct)

    seg = sitk.ReadImage(os.path.join(para.train_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8)
    seg_array = sitk.GetArrayFromImage(seg)

    # Merge the liver and liver-tumour labels into a single foreground label.
    seg_array[seg_array > 0] = 1

    # Clip grey values outside the [para.lower, para.upper] window.
    ct_array[ct_array > para.upper] = para.upper
    ct_array[ct_array < para.lower] = para.lower

    # Down-sample the CT in the axial plane and resample along z so every
    # volume ends up with a z-spacing of para.slice_thickness (1 mm).
    # The segmentation is only resampled along z (nearest-neighbour, order=0).
    ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / para.slice_thickness, para.down_scale, para.down_scale), order=3)
    seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / para.slice_thickness, 1, 1), order=0)

    # Find the first and last slice containing liver, then expand outwards.
    z = np.any(seg_array, axis=(1, 2))
    start_slice, end_slice = np.where(z)[0][[0, -1]]

    # Expand by para.expand_slice in both directions, clamped to the volume.
    start_slice = max(0, start_slice - para.expand_slice)
    end_slice = min(seg_array.shape[0] - 1, end_slice + para.expand_slice)

    # If fewer slices remain than the crop size, drop this volume.
    # Such cases are rare, so no special handling is needed.
    if end_slice - start_slice + 1 < para.size:
        print('!!!!!!!!!!!!!!!!')
        # Report the number of remaining slices (the quantity actually
        # tested above), not the pre-crop shape.
        print(file, 'has too few slices:', end_slice - start_slice + 1)
        print('!!!!!!!!!!!!!!!!')
        continue

    ct_array = ct_array[start_slice:end_slice + 1, :, :]
    seg_array = seg_array[start_slice:end_slice + 1, :, :]

    # Save the cropped volumes as nii with corrected spacing metadata.
    new_ct = sitk.GetImageFromArray(ct_array)
    new_ct.SetDirection(ct.GetDirection())
    new_ct.SetOrigin(ct.GetOrigin())
    # In-plane spacing grows by a factor of 1/down_scale after the zoom.
    # Use true division: int(1 / down_scale) silently truncates for scales
    # such as 0.75 and would record a wrong spacing.
    new_ct.SetSpacing((ct.GetSpacing()[0] / para.down_scale, ct.GetSpacing()[1] / para.down_scale, para.slice_thickness))

    new_seg = sitk.GetImageFromArray(seg_array)
    new_seg.SetDirection(ct.GetDirection())
    new_seg.SetOrigin(ct.GetOrigin())
    # The segmentation was not down-sampled in-plane, so it keeps the
    # original x/y spacing.
    new_seg.SetSpacing((ct.GetSpacing()[0], ct.GetSpacing()[1], para.slice_thickness))

    sitk.WriteImage(new_ct, os.path.join(new_ct_path, file))
    sitk.WriteImage(new_seg, os.path.join(new_seg_dir, file.replace('volume', 'segmentation').replace('.nii', '.nii.gz')))
解决思路
通过代码我们可以看到作者是将所有文件名生成一个列表,循环列表来读取每一个CT图像以及它的mask,再进行数据预处理。实际上我们可以将列表分成多个子集,然后每一个子集用一个进程来进行读取与预处理的操作
多进程模板
import multiprocessing as mp


def func(param):
    """Worker: replace this body with the real per-item work."""
    # Do something with `param`.
    pass


if __name__ == '__main__':  # Required on Windows: workers re-import this module.
    pool = mp.Pool(4)  # process pool; pass the number of worker processes
    param_list = [...]  # fill with one entry per unit of work
    for param in param_list:
        # apply_async() is non-blocking: tasks are queued and run
        # concurrently instead of waiting for each other.
        # args must be a tuple, hence the trailing comma.
        pool.apply_async(func, args=(param,))
    pool.close()  # stop accepting new tasks
    pool.join()   # block until every queued task has finished
    print("Done")
完整代码
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
import shutil
import time
import numpy as np
import SimpleITK as sitk
import scipy.ndimage as ndimage
import multiprocessing as mp
import parameter as para
# Output locations; kept at module level because spawned worker processes
# re-import this module and data_preprocess reads these globals.
new_ct_path = os.path.join(para.training_set_path, 'ct')
new_seg_dir = os.path.join(para.training_set_path, 'seg')

# Only the parent process may (re)create the output tree.  On Windows,
# multiprocessing uses the "spawn" start method, which re-imports this
# module in every worker; without this guard each worker would rmtree the
# directory and destroy output already written by the other processes.
if __name__ == '__main__':
    if os.path.exists(para.training_set_path):
        shutil.rmtree(para.training_set_path)
    os.makedirs(new_ct_path, exist_ok=True)
    os.makedirs(new_seg_dir, exist_ok=True)
def data_preprocess(data_list: list):
    """Pre-process one subset of CT volumes and write the results to disk.

    For every filename in ``data_list``: read the CT and its gold-standard
    segmentation, binarise the labels, clip the grey window, resample,
    crop to the liver region (expanded by ``para.expand_slice``), and save
    both volumes as nii under the module-level ``new_ct_path`` /
    ``new_seg_dir`` directories.

    Intended to run inside a worker process; it relies on the module-level
    configuration ``para`` and output paths being importable by the worker.
    """
    for file in data_list:

        # Read the CT volume and its gold-standard segmentation into memory.
        ct = sitk.ReadImage(os.path.join(para.train_ct_path, file), sitk.sitkInt16)
        ct_array = sitk.GetArrayFromImage(ct)

        seg = sitk.ReadImage(os.path.join(para.train_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8)
        seg_array = sitk.GetArrayFromImage(seg)

        # Merge the liver and liver-tumour labels into one foreground label.
        seg_array[seg_array > 0] = 1

        # Clip grey values outside the [para.lower, para.upper] window.
        ct_array[ct_array > para.upper] = para.upper
        ct_array[ct_array < para.lower] = para.lower

        # Down-sample the CT in-plane and resample z so every volume has a
        # z-spacing of para.slice_thickness (1 mm); the segmentation is only
        # resampled along z (nearest-neighbour).
        ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / para.slice_thickness, para.down_scale, para.down_scale), order=3)
        seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / para.slice_thickness, 1, 1), order=0)

        # Find the first and last slice containing liver, then expand
        # outwards by para.expand_slice in both directions (clamped).
        z = np.any(seg_array, axis=(1, 2))
        start_slice, end_slice = np.where(z)[0][[0, -1]]
        start_slice = max(0, start_slice - para.expand_slice)
        end_slice = min(seg_array.shape[0] - 1, end_slice + para.expand_slice)

        # If fewer slices remain than the crop size, drop this volume.
        # Such cases are rare, so no special handling is needed.
        if end_slice - start_slice + 1 < para.size:
            print('!!!!!!!!!!!!!!!!')
            # Report the remaining slice count (the quantity tested above).
            print(file, 'has too few slices:', end_slice - start_slice + 1)
            print('!!!!!!!!!!!!!!!!')
            continue

        ct_array = ct_array[start_slice:end_slice + 1, :, :]
        seg_array = seg_array[start_slice:end_slice + 1, :, :]

        # Save the cropped volumes as nii with corrected spacing metadata.
        new_ct = sitk.GetImageFromArray(ct_array)
        new_ct.SetDirection(ct.GetDirection())
        new_ct.SetOrigin(ct.GetOrigin())
        # In-plane spacing grows by 1/down_scale after the zoom.  Use true
        # division: int(1 / down_scale) silently truncates for scales such
        # as 0.75 and would record a wrong spacing.
        new_ct.SetSpacing((ct.GetSpacing()[0] / para.down_scale, ct.GetSpacing()[1] / para.down_scale, para.slice_thickness))

        new_seg = sitk.GetImageFromArray(seg_array)
        new_seg.SetDirection(ct.GetDirection())
        new_seg.SetOrigin(ct.GetOrigin())
        # The segmentation kept its in-plane resolution, so the original
        # x/y spacing still applies.
        new_seg.SetSpacing((ct.GetSpacing()[0], ct.GetSpacing()[1], para.slice_thickness))

        sitk.WriteImage(new_ct, os.path.join(new_ct_path, file))
        sitk.WriteImage(new_seg, os.path.join(new_seg_dir, file.replace('volume', 'segmentation').replace('.nii', '.nii.gz')))
if __name__ == '__main__':
    train_ct_list = os.listdir(para.train_ct_path)

    # Split the file list into 4 chunks, one per worker process: three
    # equal chunks plus a final chunk carrying the remainder.
    n_proc = 4  # number of workers; multiprocessing.cpu_count() shows how many CPUs exist
    chunk = len(train_ct_list) // n_proc
    data_list = [train_ct_list[i * chunk:(i + 1) * chunk] for i in range(n_proc - 1)]
    data_list.append(train_ct_list[(n_proc - 1) * chunk:])

    begin = time.time()
    pool = mp.Pool(n_proc)
    # Keep the AsyncResult handles: apply_async swallows worker exceptions
    # unless .get() is called, so without this a failing worker would make
    # the run "succeed" with silently missing output.
    results = [pool.apply_async(data_preprocess, args=(paths,)) for paths in data_list]
    pool.close()   # no further tasks will be submitted
    pool.join()    # wait for every worker to finish
    for r in results:
        r.get()    # re-raises any exception that occurred in a worker

    print("处理用时:{:.3f}s".format(time.time() - begin))
实现结果
在测试集20套数据上进行测试,调用4个进程耗时约在50 ~ 60s之间,换成111套数据的训练集实测20 ~ 25分钟左右,数据量增加后速度有所下降,但还是比单一进程处理要快很多。