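"""
本脚本实现以下功能 (this script performs the following steps):
1. 根据最大连通域的中心进行粗对齐(因为原始图像的位置相差较远);
   Coarse alignment by matching the centres of the largest connected
   components (the original images are located far apart).
2. 刚性配准。
   Rigid registration.
3. 根据自定义的相似度函数再次微调。
   Fine-tuning driven by a custom similarity function.
"""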
import numpy as np
from scipy import ndimage
import nibabel as nib
import SimpleITK as sitk
import random
def find_largest_connected_component_center(mask):
    """Return the index-space centroid of the largest connected component.

    Parameters
    ----------
    mask : array_like
        N-dimensional array; every non-zero element is foreground. Components
        are found with ``scipy.ndimage.label`` using its default structuring
        element.

    Returns
    -------
    tuple of float or None
        The mean index along each axis of the largest component (``(x, y, z)``
        for a 3-D volume), or ``None`` if *mask* contains no foreground.
    """
    labeled_mask, num_features = ndimage.label(mask)
    if num_features == 0:
        return None
    # Size = voxel count per label. The original summed the mask values, which
    # ranks components by intensity mass rather than size on non-binary input.
    sizes = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)[1:]
    # np.argmax returns the first maximum, so ties resolve to the lowest label.
    max_label = int(np.argmax(sizes)) + 1
    # np.nonzero returns only the component's coordinates. This avoids building
    # the full np.indices grid (one whole-volume int array per axis).
    coords = np.nonzero(labeled_mask == max_label)
    return tuple(axis_coords.mean() for axis_coords in coords)
RefImg = 'D:/test/mask_4.nii.gz'
MovImg = 'D:/test/maskOri.nii'
TranslatedImg = 'D:/test/mask44_moved.nii'  # coarse-alignment result
RegImg = 'D:/test/registered_image.nii'  # rigid-registration result
OutImg = 'D:/test/Out_image.nii'  # final fine-tuned result

maskMove = nib.load(MovImg)
maskRef = nib.load(RefImg)
header = maskMove.header

# Moving mask: every non-zero voxel is foreground. Stored as uint8 because
# nibabel >= 5 rejects int64 data in Nifti1Image without an explicit dtype.
Movedata = (maskMove.get_fdata() != 0).astype(np.uint8)

# Reference mask: labels 1 and 4 are foreground, every other label is background.
Refdata = maskRef.get_fdata()
Refdata = np.where(np.isin(Refdata, (1, 4)), 1.0, 0.0)

ref_center = find_largest_connected_component_center(Refdata)
mov_center = find_largest_connected_component_center(Movedata)
if ref_center is None or mov_center is None:
    raise ValueError('No foreground voxels found in the reference or moving mask')

# Coarse alignment: translate the moving mask so that the centroid of its
# largest component lands on the reference component's centroid (voxel units).
translation = np.subtract(ref_center, mov_center)
shift_vox = np.round(translation).astype(int)
# With order=0 and integer offsets this is an exact voxel translation. Unlike
# np.roll, voxels pushed past the border are dropped (filled with 0) instead of
# wrapping around to the opposite side of the volume.
nii_data_translated = ndimage.shift(Movedata, shift_vox, order=0, mode='constant', cval=0)

# Save in the reference's physical space. The shift above aligned the mask to
# the reference voxel grid, so the reference affine/header is the matching
# geometry. NOTE: like the voxel-index shift itself, this assumes both images
# have comparable voxel spacing.
nii_img = nib.Nifti1Image(nii_data_translated, maskRef.affine, maskRef.header)
nii_img.set_data_dtype(np.uint8)
nib.save(nii_img, TranslatedImg)
# Rigid registration of the coarsely aligned moving mask to the reference (SimpleITK).
sitk_ref_labels = sitk.ReadImage(RefImg)
# Binarize the reference exactly like Refdata (labels 1 and 4 -> 1, others -> 0).
# MeanSquares then compares a 0/1 mask with a 0/1 mask; with raw label values
# such as 4, the metric would be dominated by the label value mismatch.
sitk_Refimg = sitk.Cast((sitk_ref_labels == 1) | (sitk_ref_labels == 4), sitk.sitkFloat32)
sitk_Movimg = sitk.Cast(sitk.ReadImage(TranslatedImg), sitk.sitkFloat32)

rigid_registration = sitk.ImageRegistrationMethod()
rigid_registration.SetMetricAsMeanSquares()
rigid_registration.SetOptimizerAsGradientDescent(learningRate=10.0, numberOfIterations=1000, convergenceMinimumValue=1e-6, convergenceWindowSize=100)
# Euler3DTransform mixes rotation (radians) and translation (physical units)
# parameters. Per-parameter scales let one learning rate suit both.
rigid_registration.SetOptimizerScalesFromPhysicalShift()
# Linear interpolation during optimisation gives a smooth metric.
rigid_registration.SetInterpolator(sitk.sitkLinear)
centering_transform = sitk.CenteredTransformInitializer(sitk_Refimg, sitk_Movimg, sitk.Euler3DTransform(), sitk.CenteredTransformInitializerFilter.GEOMETRY)
rigid_registration.SetInitialTransform(centering_transform)

final_transform = rigid_registration.Execute(sitk_Refimg, sitk_Movimg)
# Nearest-neighbour resampling keeps the output mask binary; linear
# interpolation would introduce fractional values along the mask boundary.
resampled_img = sitk.Resample(sitk_Movimg, sitk_Refimg, final_transform, sitk.sitkNearestNeighbor, 0.0, sitk_Movimg.GetPixelID())
sitk.WriteImage(resampled_img, RegImg)
# Reload the rigidly registered mask for the fine-tuning search.
maskRegMove = nib.load(RegImg)
MoveRegdata = maskRegMove.get_fdata()

# Array shape of the registered mask and its integer mid-point index
# (intended as a rotation centre).
shape = maskRegMove.shape
center = np.floor_divide(np.asarray(shape), 2)

# Search ranges: rotation (presumably degrees) and translation (voxels).
rotate_range, translate_range = (-10, 10), (-10, 10)

# A candidate is accepted once its overlap score exceeds this value.
threshold = 0.7
def overlap(mask0, mask1):
    """Return the fraction of *mask1* covered by *mask0*.

    Computed as ``sum(mask0 * mask1) / sum(mask1)``. For binary masks this is
    ``|mask0 ∩ mask1| / |mask1|``. The score is asymmetric: only *mask1*'s
    area is in the denominator.

    Returns 0.0 when *mask1* sums to zero. The original divided by zero
    there and produced nan; nan never exceeds a threshold, so the caller's
    search loop could never terminate.
    """
    intersection = np.sum(mask0 * mask1)
    total = np.sum(mask1)
    if total == 0:
        return 0.0
    return intersection / total
# Random-search fine-tuning. Draw random integer voxel shifts within
# translate_range, keep the best-scoring candidate, and stop as soon as its
# overlap with Refdata exceeds threshold. The search is capped by
# max_iterations so it cannot spin forever.
# NOTE(review): rotate_range is defined, but rotations are not applied. The
# original only translated; its rotation attempt was commented out.
max_iterations = 10000

# Start from the unshifted registration result: it may already be good enough.
best_shift = np.zeros(MoveRegdata.ndim, dtype=int)
best_data = MoveRegdata
best_score = overlap(MoveRegdata, Refdata)

for iteration in range(max_iterations):
    if best_score > threshold:
        break
    translate = np.array([random.randint(*translate_range) for _ in range(MoveRegdata.ndim)])
    # Exact integer translation with zero fill (np.roll would wrap voxels
    # around to the opposite border).
    candidate = ndimage.shift(MoveRegdata, translate, order=0, mode='constant', cval=0.0)
    score = overlap(candidate, Refdata)
    if score > best_score:
        best_score, best_shift, best_data = score, translate, candidate

if best_score <= threshold:
    print(f'Warning: overlap threshold {threshold} not reached after {max_iterations} trials; '
          f'saving best candidate (overlap {best_score:.3f}, shift {best_shift.tolist()})')

nii_data_translated = best_data
# The registered mask was resampled onto the reference grid, so its affine is
# the correct geometry for the output (affine=None would discard it).
nii_img = nib.Nifti1Image(nii_data_translated.astype(np.float32), maskRegMove.affine)
nib.save(nii_img, OutImg)