python 常用函数和自定义函数整理

以下函数主要用于记录和保存,方便自己查阅。
---------持续更新

1. 3D 图像处理

def numpy2sitk(arr, sitk_ori_img):
    """Convert a numpy array to a SimpleITK image with a reference image's geometry.

    Args:
        arr (np.ndarray): voxel data, passed to ``sitk.GetImageFromArray``.
        sitk_ori_img (sitk.Image): reference image whose origin, spacing and
            direction are copied onto the new image.

    Returns:
        sitk.Image: the converted image.
    """
    new_img = sitk.GetImageFromArray(arr)
    # Copy the physical-space metadata property by property (same order as
    # before: origin, spacing, direction).
    for setter, getter in (
        (new_img.SetOrigin, sitk_ori_img.GetOrigin),
        (new_img.SetSpacing, sitk_ori_img.GetSpacing),
        (new_img.SetDirection, sitk_ori_img.GetDirection),
    ):
        setter(getter())
    return new_img
import dicom2nifti
import os
import pydicom as pdic
 # dicom convert to nii file
 # method 1
def dicom2nii(dicom_path, save_root):
    """Convert every patient's DICOM folder under ``dicom_path`` to NIfTI.

    Each sub-directory ``dicom_path/<patient>`` is converted with
    ``dicom2nifti.convert_directory`` into ``save_root/<patient>/`` and the
    file it produces is renamed to ``<patient>.nii.gz``.

    Args:
        dicom_path (str): directory whose sub-directories are per-patient
            DICOM folders (a trailing slash is accepted but not required).
        save_root (str): output root; one sub-directory is created per patient.
    """
    for patient in os.listdir(dicom_path):
        src_dir = os.path.join(dicom_path, patient)
        if not os.path.isdir(src_dir):
            # Stray files next to the patient folders are not DICOM series.
            continue
        out_dir = os.path.join(save_root, patient)
        os.makedirs(out_dir, exist_ok=True)
        # Snapshot the directory so only the file dicom2nifti creates now is
        # renamed, never a leftover from an earlier run.
        before = set(os.listdir(out_dir))
        dicom2nifti.convert_directory(src_dir, out_dir)
        created = sorted(name for name in set(os.listdir(out_dir)) - before
                         if name.endswith('.nii.gz'))
        if not created:
            print(f'no .nii.gz produced for {patient}')
            continue
        create_name = created[0]
        print(create_name)
        # os.replace also overwrites an existing <patient>.nii.gz on Windows.
        os.replace(os.path.join(out_dir, create_name),
                   os.path.join(out_dir, f'{patient}.nii.gz'))
		    
# method 2
def dcm2nii(dcm_path, nii_path):
    """Convert a single DICOM file to an image file written by SimpleITK.

    The in-plane spacing comes from ``ImagerPixelSpacing`` when the tag is
    present, otherwise a fixed fallback of 0.278875 is used; the third
    spacing component (if the image has one) is 1.

    Args:
        dcm_path (str): path of the input DICOM file.
        nii_path (str): output path (e.g. ``*.nii.gz``).
    """
    print(dcm_path)
    # pydicom.read_file is deprecated (removed in pydicom 3.0); dcmread is the
    # supported reader.
    dcm = pdic.dcmread(dcm_path)
    if 'ImagerPixelSpacing' in dcm:
        # DICOM stores the pair as (row spacing, column spacing), i.e. (y, x),
        # while SimpleITK spacing is ordered (x, y, ...).
        row_sp, col_sp = (float(v) for v in dcm.ImagerPixelSpacing)
    else:
        row_sp = col_sp = 0.278875
    img = sitk.GetImageFromArray(dcm.pixel_array)
    # Give exactly one spacing value per image dimension (2-D or 3-D).
    img.SetSpacing((col_sp, row_sp, 1.0)[:img.GetDimension()])
    sitk.WriteImage(img, nii_path)
def resample_img(itk_image, out_spacing=(2.0, 2.0, 2.0), is_label=False, default_value=0):
    """Resample a SimpleITK image to ``out_spacing`` (default: 2 mm isotropic).

    Origin and direction are kept. The output size is scaled so the physical
    extent (size * spacing) stays approximately the same.

    Args:
        itk_image (sitk.Image): image to resample.
        out_spacing (sequence of float): target spacing, one value per image
            axis, in the same order as ``itk_image.GetSpacing()``.
        is_label (bool): use nearest-neighbour interpolation (keeps label
            values intact) instead of B-spline.
        default_value (float): value written to output voxels that map
            outside the input image.

    Returns:
        sitk.Image: the resampled image.

    Raises:
        ValueError: if ``out_spacing`` does not have one entry per axis.
    """
    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()
    if len(out_spacing) != len(original_spacing):
        raise ValueError(
            f'out_spacing has {len(out_spacing)} values, image has {len(original_spacing)} axes')
    out_spacing = [float(s) for s in out_spacing]

    # Keep size * spacing constant per axis; never produce an empty axis.
    out_size = [
        max(1, int(np.round(size * (sp / new_sp))))
        for size, sp, new_sp in zip(original_size, original_spacing, out_spacing)
    ]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # Previously GetPixelIDValue() was passed here: that is a pixel-type code,
    # not an intensity, so voxels outside the input got an arbitrary value.
    resample.SetDefaultPixelValue(default_value)
    resample.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline)

    return resample.Execute(itk_image)
def _centered_window(first, last, size, extent):
    """Return a half-open window [lo, hi) of ``size`` voxels centred on [first, last].

    ``first``/``last`` are inclusive indices along one axis. The window is
    shifted (not shrunk) to stay inside [0, extent); it is shorter than
    ``size`` only when the axis itself is shorter than ``size``.
    """
    length = last - first + 1
    lo = first - (size - length) // 2
    lo = max(0, min(lo, extent - size))
    hi = min(lo + size, extent)
    return int(lo), int(hi)


def roi_extract(img_path, lab_path, patch_size):
    """Crop a ``patch_size``-sized region centred on the label's foreground.

    Along each axis, the bounding box of the non-zero label voxels is grown
    (or shrunk) symmetrically to ``patch_size`` voxels. A window that would
    cross the volume border is shifted back inside it, so every axis is
    exactly ``patch_size`` long whenever the volume is large enough.

    Args:
        img_path (str): image path.
        lab_path (str): label path (expected on the same grid as the image).
        patch_size (int): target edge length, in voxels, along every axis.

    Returns:
        tuple(sitk.Image, sitk.Image): cropped image and label. Their origin
        is the physical position of the first cropped voxel, so they stay
        aligned with the source volume.

    Raises:
        ValueError: if the label contains no non-zero voxel.
    """
    img = sitk.ReadImage(img_path)
    lab = sitk.ReadImage(lab_path)
    img_arr = sitk.GetArrayFromImage(img)
    lab_arr = sitk.GetArrayFromImage(lab)
    if not np.any(lab_arr):
        raise ValueError(f'no foreground found in label {lab_path}')

    windows = []
    for axis in range(lab_arr.ndim):
        other_axes = tuple(a for a in range(lab_arr.ndim) if a != axis)
        occupied = np.flatnonzero(np.any(lab_arr, axis=other_axes))
        windows.append(
            _centered_window(occupied[0], occupied[-1], patch_size, lab_arr.shape[axis]))

    crop = tuple(slice(lo, hi) for lo, hi in windows)
    sitk_img = numpy2sitk(img_arr[crop], img)
    sitk_lab = numpy2sitk(lab_arr[crop], lab)

    # numpy2sitk copies the uncropped origin; move it to the first cropped
    # voxel. GetArrayFromImage orders axes opposite to SimpleITK indices,
    # hence the reversal when building the index.
    start_index = tuple(lo for lo, _ in reversed(windows))
    sitk_img.SetOrigin(img.TransformIndexToPhysicalPoint(start_index))
    sitk_lab.SetOrigin(lab.TransformIndexToPhysicalPoint(start_index))
    return sitk_img, sitk_lab
def slice_slide_crop(img_path, lab_path, patch_num, patch_size,
                     img_out_dir='0001_0001/crop_img',
                     lab_out_dir='0001_0001/crop_lab',
                     prefix='0001_0001'):
    """Cut ``patch_num + 1`` overlapping patches of ``patch_size`` slices along the third axis.

    The first patch starts at slice 0 and the last one ends exactly at the
    last slice; the starts in between are evenly spaced (integer division).

    Args:
        img_path (str): image path.
        lab_path (str): label path.
        patch_num (int): number of steps; ``patch_num + 1`` patches are written.
        patch_size (int): number of slices per patch.
        img_out_dir (str): output directory for image patches (created if missing).
        lab_out_dir (str): output directory for label patches (created if missing).
        prefix (str): file-name prefix; files are ``<prefix>_s<i>.nii.gz``.

    Returns:
        bool: True if the patches were written, False if the parameters
        cannot cover the volume with overlapping patches.
    """
    img = sitk.ReadImage(img_path)
    lab = sitk.ReadImage(lab_path)
    depth = img.GetSize()[2]
    if patch_size > depth:
        print('patch_size is larger than the number of slices')
        return False
    if patch_num * patch_size <= depth:
        # Patches would not overlap (possibly leaving gaps): refuse as before.
        print('patch_size/patch_num is too small')
        return False

    os.makedirs(img_out_dir, exist_ok=True)
    os.makedirs(lab_out_dir, exist_ok=True)
    span = depth - patch_size
    for i in range(patch_num + 1):
        # i * span // patch_num reaches span exactly at i == patch_num, so the
        # final patch always includes the last slice.
        start = i * span // patch_num
        stop = start + patch_size
        name = f'{prefix}_s{i}.nii.gz'
        sitk.WriteImage(img[:, :, start:stop], os.path.join(img_out_dir, name))
        sitk.WriteImage(lab[:, :, start:stop], os.path.join(lab_out_dir, name))
    return True

2. PyTorch 相关处理

# 模型加载
from collections import OrderedDict
def load_distributed_model(model_path, model=None, device=None):
    """Load weights saved from a (Distributed)DataParallel-wrapped model.

    Such checkpoints prefix every key with ``module.``. The prefix is
    stripped (only where present) so the weights fit the plain model.

    Args:
        model_path (str): path of the saved ``state_dict``.
        model (torch.nn.Module, optional): model to load into; defaults to
            ``DGCNN()`` (the author's own model class, defined elsewhere).
        device (str or torch.device, optional): target device; defaults to
            ``"cuda"``.

    Returns:
        torch.nn.Module: the model, moved to ``device``, with loaded weights.
    """
    device = torch.device("cuda" if device is None else device)
    if model is None:
        model = DGCNN()  # the author's own model
    model = model.to(device)
    # map_location lets a checkpoint saved on one device load on another.
    state_dict = torch.load(model_path, map_location=device)
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    return model

# 多类分割dice损失
def generalized_dice_loss(pred, target, epsilon=1e-5):
    """Compute the generalized (class-weighted) Dice loss.

    Each class is weighted by the inverse square of its ground-truth volume,
    so small structures weigh as much as large ones.

    Args:
        pred (tensor): prediction after softmax, shape
            (batch_size, channels, *spatial), e.g. (B, C, H, W) or (B, C, D, H, W).
        target (tensor): one-hot ground truth with the same shape as ``pred``.
        epsilon (float): smoothing term that avoids division by zero.

    Returns:
        tensor: scalar loss; for inputs in [0, 1] it lies in [0, 1].
    """
    # Reduce over batch and all spatial axes, keeping the class axis (dim 1).
    reduce_dims = [0] + list(range(2, pred.dim()))
    class_volume = torch.sum(target, dim=reduce_dims)  # (n_class,)
    # NOTE(review): a class absent from ``target`` gets weight 1/epsilon, which
    # can dominate the loss; confirm this is acceptable for the data.
    wei = 1.0 / (class_volume ** 2 + epsilon)
    intersection = torch.sum(wei * torch.sum(pred * target, dim=reduce_dims))
    union = torch.sum(wei * torch.sum(pred + target, dim=reduce_dims))
    gldice_loss = 1 - (2.0 * intersection) / (union + epsilon)
    return gldice_loss

3. python 相关处理

# Skeleton analysis: locate end points (tips) and bifurcation points (bifs) of a 2D skeleton
def get_tips_bifs(skele):
    """Get tips and bifurcations of a 2D skeleton by counting each pixel's neighbours.

    Every foreground pixel is scored with a 3x3 box sum, i.e. itself plus its
    8-connected foreground neighbours. A score of 2 (exactly one neighbour)
    marks a tip. Scores >= 4 (three or more neighbours) mark bifurcation
    pixels, which ``get_bifs_with_single_pixel`` reduces to one point each.

    Args:
        skele (np.array): 2D skeleton of a binary image; any non-zero value
            counts as foreground.

    Returns:
        tips, bifs: (N, 2) array of tip coordinates in array index order, and
        the bifurcation points returned by ``get_bifs_with_single_pixel``.
    """
    # Binarize to int: skeletons are often bool, and ndimage.convolve keeps the
    # input dtype, which would collapse the neighbour counts to True/False.
    binary = (np.asarray(skele) > 0).astype(np.int32)
    kernel = np.ones((3, 3), dtype=np.int32)
    # mode='constant' pads with background; the default 'reflect' mirrors
    # border pixels and would count them as their own neighbours.
    neighbour_count = ndimage.convolve(binary, kernel, mode='constant', cval=0)
    neighbour_count *= binary  # keep scores only on skeleton pixels
    tips = np.argwhere(neighbour_count == 2)  # centre pixel + one neighbour
    bifs = get_bifs_with_single_pixel(neighbour_count)
    return tips, bifs


def get_bifs_with_single_pixel(map):
    """Reduce each bifurcation (which may span several pixels) to a single point.

    Args:
        map (np.array): per-pixel neighbour score (centre pixel + neighbours);
            values >= 4 mark bifurcation pixels. The array is not modified.

    Returns:
        np.array: (N, 2) integer array with one point per connected group of
        bifurcation pixels, in array index order; shape (0, 2) if there is none.
    """
    # >= 4 means the pixel has at least three skeleton neighbours. Build a new
    # mask instead of overwriting the caller's array in place.
    bif_mask = np.asarray(map) >= 4
    regions = regionprops(label(bif_mask))
    bifs = []
    for region in regions:
        coords = region.coords  # pixel coordinates of this bifurcation region
        # Representative point: the most frequent value (mode) of each
        # coordinate. NOTE(review): the two modes are chosen independently,
        # so the point is not guaranteed to be one of the region's pixels.
        x = np.argmax(np.bincount(coords[:, 0]))
        y = np.argmax(np.bincount(coords[:, 1]))
        bifs.append([x, y])
    return np.array(bifs, dtype=np.intp).reshape(-1, 2)

4. 错误情况

# Fix for sitk.ReadImage failing with a non-orthonormal direction error: rewrite the sform/qform information
# For every patient folder under ``root``, re-write <patient>/<patient>.nii.gz
# with its qform/sform passed back through nibabel's setters.
for patient in os.listdir(root):
    new_path = os.path.join(root, patient)
    if not os.path.isdir(new_path):
        continue  # skip stray files next to the patient folders
    nii_file = os.path.join(new_path, f'{patient}.nii.gz')
    img = nib.load(nii_file)
    # Round-trip qform/sform through the setters so nibabel re-encodes them in
    # the header (workaround for SimpleITK's direction-matrix error).
    qform = img.get_qform()
    img.set_qform(qform)
    sform = img.get_sform()
    img.set_sform(sform)
    # Save to a temporary file and swap it in, so an interrupted save cannot
    # destroy the only copy of the input image.
    tmp_file = os.path.join(new_path, f'{patient}.tmp.nii.gz')
    nib.save(img, tmp_file)
    os.replace(tmp_file, nii_file)
# Convert spacing and origin from mm to cm
def spacing_adjust(nii_data):
    """Rescale an image's spacing and origin from millimetres to centimetres.

    Every spacing and origin component is divided by 10. The image is
    modified in place and also returned.

    Args:
        nii_data (sitk.Image): image to rescale.

    Returns:
        sitk.Image: the same object, with updated spacing and origin.
    """
    scaled_spacing = [value / 10. for value in nii_data.GetSpacing()]
    scaled_origin = [value / 10. for value in nii_data.GetOrigin()]
    nii_data.SetSpacing(scaled_spacing)
    nii_data.SetOrigin(scaled_origin)
    return nii_data

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值