Preprocessing Data with Python Multiprocessing

Problem Scenario

While going through the KiU-Net source code today, I noticed that its data preprocessing is very time-consuming (about one hour for 111 volumes).
Since preprocessing each volume is independent of the others, the work can run in parallel, so I decided to use multiprocessing to speed it up.

Original Code

Source link

"""
获取可用于训练网络的训练数据集
需要四十分钟左右,产生的训练数据大小3G左右
"""

import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
import shutil
from time import time

import numpy as np
from tqdm import tqdm
import SimpleITK as sitk
import scipy.ndimage as ndimage

import parameter as para


if os.path.exists(para.training_set_path):
    shutil.rmtree(para.training_set_path)

new_ct_path = os.path.join(para.training_set_path, 'ct')
new_seg_dir = os.path.join(para.training_set_path, 'seg')

os.mkdir(para.training_set_path)
os.mkdir(new_ct_path)
os.mkdir(new_seg_dir)

start = time()
for file in tqdm(os.listdir(para.train_ct_path)):

    # Read the CT volume and its ground-truth segmentation into memory
    print(os.path.join(para.train_ct_path, file))
    ct = sitk.ReadImage(os.path.join(para.train_ct_path, file), sitk.sitkInt16)
    ct_array = sitk.GetArrayFromImage(ct)

    seg = sitk.ReadImage(os.path.join(para.train_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8)
    seg_array = sitk.GetArrayFromImage(seg)

    # Merge the liver and liver-tumor labels in the ground truth into one label
    seg_array[seg_array > 0] = 1

    # Clip grey values that fall outside the thresholds
    ct_array[ct_array > para.upper] = para.upper
    ct_array[ct_array < para.lower] = para.lower

    # Downsample the CT in the axial plane and resample along z so that
    # every volume has a z-axis spacing of 1 mm
    ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / para.slice_thickness, para.down_scale, para.down_scale), order=3)
    seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / para.slice_thickness, 1, 1), order=0)

    # Find the slices where the liver region starts and ends, then expand outwards
    z = np.any(seg_array, axis=(1, 2))
    start_slice, end_slice = np.where(z)[0][[0, -1]]

    # Expand by expand_slice slices in each direction
    start_slice = max(0, start_slice - para.expand_slice)
    end_slice = min(seg_array.shape[0] - 1, end_slice + para.expand_slice)

    # If fewer than size slices remain, discard this volume; such cases
    # are rare, so this is not a concern
    if end_slice - start_slice + 1 < para.size:
        print('!!!!!!!!!!!!!!!!')
        print(file, 'has too few slices:', ct_array.shape[0])
        print('!!!!!!!!!!!!!!!!')
        continue

    ct_array = ct_array[start_slice:end_slice + 1, :, :]
    seg_array = seg_array[start_slice:end_slice + 1, :, :]

    # Finally, save the data as nii
    new_ct = sitk.GetImageFromArray(ct_array)

    new_ct.SetDirection(ct.GetDirection())
    new_ct.SetOrigin(ct.GetOrigin())
    new_ct.SetSpacing((ct.GetSpacing()[0] * int(1 / para.down_scale), ct.GetSpacing()[1] * int(1 / para.down_scale), para.slice_thickness))

    new_seg = sitk.GetImageFromArray(seg_array)

    new_seg.SetDirection(ct.GetDirection())
    new_seg.SetOrigin(ct.GetOrigin())
    new_seg.SetSpacing((ct.GetSpacing()[0], ct.GetSpacing()[1], para.slice_thickness))

    sitk.WriteImage(new_ct, os.path.join(new_ct_path, file))
    sitk.WriteImage(new_seg, os.path.join(new_seg_dir, file.replace('volume', 'segmentation').replace('.nii', '.nii.gz')))

Solution

From the code we can see that the author builds a list of all file names and loops over it, reading each CT volume and its mask and preprocessing them one at a time. Since the volumes are independent, we can instead split the list into several sublists and let each process handle the reading and preprocessing of one sublist, as sketched below.
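To make the splitting step concrete, here is a minimal sketch of such a helper. chunk_list is a hypothetical name introduced for illustration; numpy's np.array_split would do the same job:

def chunk_list(items, n):
    # Split items into n sublists of near-equal length;
    # the first len(items) % n sublists get one extra element
    k, r = divmod(len(items), n)
    chunks = []
    start = 0
    for i in range(n):
        end = start + k + (1 if i < r else 0)
        chunks.append(items[start:end])
        start = end
    return chunks

print(chunk_list(list(range(10)), 4))  # [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]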

Multiprocessing Template

import multiprocessing as mp

def func(param):
    # Do something with param
    pass

if __name__ == '__main__':  # This guard is required on Windows!
    pool = mp.Pool(n)  # Create a process pool with n worker processes
    param_list = [param1, param2, ...]
    for param in param_list:
        # apply_async() is asynchronous and non-blocking: it does not wait
        # for the current task to finish before submitting the next one
        pool.apply_async(func, args=(param,))  # args must be passed as a tuple
    pool.close()  # Close the pool; no new tasks can be submitted
    pool.join()   # Block the main process until all workers have finished
    print("Done")

Complete Code

import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
import shutil
import time

import numpy as np
import SimpleITK as sitk
import scipy.ndimage as ndimage
import multiprocessing as mp
import parameter as para

if os.path.exists(para.training_set_path):
    shutil.rmtree(para.training_set_path)

new_ct_path = os.path.join(para.training_set_path, 'ct')
new_seg_dir = os.path.join(para.training_set_path, 'seg')


os.makedirs(para.training_set_path, exist_ok=True)
os.makedirs(new_ct_path, exist_ok=True)
os.makedirs(new_seg_dir, exist_ok=True)

def data_preprocess(data_list: list):
    for file in data_list:
        # Read the CT volume and its ground-truth segmentation into memory
        ct = sitk.ReadImage(os.path.join(para.train_ct_path, file), sitk.sitkInt16)
        ct_array = sitk.GetArrayFromImage(ct)

        seg = sitk.ReadImage(os.path.join(para.train_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8)
        seg_array = sitk.GetArrayFromImage(seg)

        # Merge the liver and liver-tumor labels in the ground truth into one label
        seg_array[seg_array > 0] = 1

        # Clip grey values that fall outside the thresholds
        ct_array[ct_array > para.upper] = para.upper
        ct_array[ct_array < para.lower] = para.lower

        # Downsample the CT in the axial plane and resample along z so that
        # every volume has a z-axis spacing of 1 mm
        ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / para.slice_thickness, para.down_scale, para.down_scale), order=3)
        seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / para.slice_thickness, 1, 1), order=0)

        # Find the slices where the liver region starts and ends, then expand outwards
        z = np.any(seg_array, axis=(1, 2))
        start_slice, end_slice = np.where(z)[0][[0, -1]]

        # Expand by expand_slice slices in each direction
        start_slice = max(0, start_slice - para.expand_slice)
        end_slice = min(seg_array.shape[0] - 1, end_slice + para.expand_slice)

        # If fewer than size slices remain, discard this volume; such cases
        # are rare, so this is not a concern
        if end_slice - start_slice + 1 < para.size:
            print('!!!!!!!!!!!!!!!!')
            print(file, 'has too few slices:', ct_array.shape[0])
            print('!!!!!!!!!!!!!!!!')
            continue

        ct_array = ct_array[start_slice:end_slice + 1, :, :]
        seg_array = seg_array[start_slice:end_slice + 1, :, :]

        # Finally, save the data as nii
        new_ct = sitk.GetImageFromArray(ct_array)

        new_ct.SetDirection(ct.GetDirection())
        new_ct.SetOrigin(ct.GetOrigin())
        new_ct.SetSpacing((ct.GetSpacing()[0] * int(1 / para.down_scale), ct.GetSpacing()[1] * int(1 / para.down_scale), para.slice_thickness))

        new_seg = sitk.GetImageFromArray(seg_array)

        new_seg.SetDirection(ct.GetDirection())
        new_seg.SetOrigin(ct.GetOrigin())
        new_seg.SetSpacing((ct.GetSpacing()[0], ct.GetSpacing()[1], para.slice_thickness))

        sitk.WriteImage(new_ct, os.path.join(new_ct_path, file))
        sitk.WriteImage(new_seg, os.path.join(new_seg_dir, file.replace('volume', 'segmentation').replace('.nii', '.nii.gz')))

if __name__ == '__main__':
    train_ct_list = os.listdir(para.train_ct_path)
    
    # Split train_ct_list into 4 roughly equal sublists
    nums = len(train_ct_list) // 4
    data_list = []
    for i in range(3):
        data_list.append(train_ct_list[i * nums:(i + 1) * nums])
    data_list.append(train_ct_list[3 * nums:])

    begin = time.time()
    pool = mp.Pool(4)  # 4 worker processes; mp.cpu_count() reports the number of available CPUs
    for paths in data_list:
        pool.apply_async(data_preprocess, args=(paths,))
    pool.close()
    pool.join()
    print("Preprocessing took {:.3f}s".format(time.time() - begin))

Results

Tested on a 20-volume test set, 4 processes took about 50 to 60 seconds. On the full 111-volume training set it took roughly 20 to 25 minutes; throughput drops somewhat as the data volume grows, but it is still much faster than single-process preprocessing.
