import os
import glob
import SimpleITK as sitk
import numpy as np
import vtk
from vtkmodules.util import numpy_support
import cv2

# Input locations: one folder of slice images and one folder of mask images.
slice_image_folder = R'path_to_imgs'
mask_image_folder = R'path_to_masks'

# Collect every slice (*.jpg) and mask (*.png) path, sorted by path string.
# NOTE(review): string sorting puts "10.jpg" before "2.jpg"; this presumably
# relies on zero-padded file names to keep slices in order - confirm.
slice_image_files = glob.glob(os.path.join(slice_image_folder, '*.jpg'))
slice_image_files.sort()
mask_image_files = glob.glob(os.path.join(mask_image_folder, '*.png'))
mask_image_files.sort()


# Read one image as grayscale, failing loudly when it cannot be decoded
def _read_grayscale(path):
    """Return the image at *path* as a single-channel array.

    ``cv2.imread`` reports failure by returning ``None`` rather than raising,
    so the result is checked here to give an error that names the file.

    Raises:
        ValueError: If the file could not be read.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"Could not read image file: {path}")
    return image


# Read the slice images and the masks and stack each set into a 3-D array
def read_images(slice_files, mask_files):
    """Load slice and mask images and stack each set along a third axis.

    Args:
        slice_files: Paths of the slice images, in stacking order.
        mask_files: Paths of the mask images, in the same order.

    Returns:
        A ``(slices_array, masks_array)`` tuple. Both arrays are built with
        ``np.stack(..., axis=2)``, so index ``k`` along the last axis is the
        ``k``-th file. Mask values are 0 or 255.

    Raises:
        ValueError: If the two lists differ in length, if they are empty, or
            if any file cannot be read.
    """
    slice_files = list(slice_files)
    mask_files = list(mask_files)

    # Validate before touching the disk so a mismatch fails fast.
    if len(slice_files) != len(mask_files):
        raise ValueError("Number of slice images and mask images must be the same.")
    if not slice_files:
        raise ValueError("No slice images or mask images were provided.")

    slices = [_read_grayscale(path) for path in slice_files]
    slices_array = np.stack(slices, axis=2)

    binary_masks = []
    for path in mask_files:
        mask = _read_grayscale(path)
        # Pixels equal to 50 are cleared before thresholding.
        # NOTE(review): 50 is presumably a label value to be ignored - confirm.
        mask[mask == 50] = 0
        # THRESH_BINARY: pixels strictly greater than 1 become 255, others 0.
        _, binary_mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
        binary_masks.append(binary_mask)
    masks_array = np.stack(binary_masks, axis=2)

    return slices_array, masks_array


# Convert a numpy array into a SimpleITK image
def numpy_to_sitk(image_array):
    """Return the SimpleITK image produced by ``sitk.GetImageFromArray``."""
    return sitk.GetImageFromArray(image_array)


# Wrap a numpy volume in a vtkImageData object
def _as_vtk_image(voxels):
    """Copy *voxels* into a new ``vtkImageData`` with unit spacing.

    The array is flattened in Fortran order (first axis varies fastest) and
    stored as float scalars; the image dimensions are ``voxels.shape``.
    """
    scalars = numpy_support.numpy_to_vtk(
        num_array=voxels.ravel(order='F'),
        deep=True,
        array_type=vtk.VTK_FLOAT,
    )
    image_data = vtk.vtkImageData()
    image_data.SetDimensions(voxels.shape)
    image_data.SetSpacing((1.0, 1.0, 1.0))
    image_data.GetPointData().SetScalars(scalars)
    return image_data


# Build a shaded, linearly interpolated volume from a vtkImageData object
def _build_volume(image_data, upper_scalar, upper_rgb):
    """Return a ``vtkVolume`` rendering *image_data*.

    Colour runs from black at scalar 0 to *upper_rgb* at *upper_scalar*;
    opacity runs from 0.0 at scalar 0 to 1.0 at *upper_scalar*.
    """
    color_map = vtk.vtkColorTransferFunction()
    color_map.AddRGBPoint(0, 0.0, 0.0, 0.0)
    color_map.AddRGBPoint(upper_scalar, *upper_rgb)

    opacity_map = vtk.vtkPiecewiseFunction()
    opacity_map.AddPoint(0, 0.0)
    opacity_map.AddPoint(upper_scalar, 1.0)

    mapper = vtk.vtkSmartVolumeMapper()
    mapper.SetInputData(image_data)

    appearance = vtk.vtkVolumeProperty()
    appearance.SetColor(color_map)
    appearance.SetScalarOpacity(opacity_map)
    appearance.ShadeOn()
    appearance.SetInterpolationTypeToLinear()

    actor = vtk.vtkVolume()
    actor.SetMapper(mapper)
    actor.SetProperty(appearance)
    return actor


# 3-D visualisation
def plot_3d(image_array, mask_array):
    """Volume-render an image stack and its mask in one interactive window.

    Args:
        image_array: 3-D numpy array, rendered in grayscale with transfer
            functions spanning scalars 0-255.
        mask_array: 3-D numpy array, rendered in red with transfer functions
            spanning scalars 0-1.
            NOTE(review): masks may hold 0/255 rather than 0/1; values above
            1 presumably clamp to the end point - confirm.

    Any exception raised while building or showing the scene is caught and
    printed instead of being propagated.
    """
    try:
        image_data = _as_vtk_image(image_array)
        mask_data = _as_vtk_image(mask_array)

        scene = vtk.vtkRenderer()
        # Grayscale image volume first, then the red mask volume.
        scene.AddVolume(_build_volume(image_data, 255, (1.0, 1.0, 1.0)))
        scene.AddVolume(_build_volume(mask_data, 1, (1.0, 0.0, 0.0)))

        window = vtk.vtkRenderWindow()
        window.AddRenderer(scene)

        interactor = vtk.vtkRenderWindowInteractor()
        interactor.SetRenderWindow(window)

        # Draw once, then hand control to the interactor.
        window.Render()
        interactor.Start()
    except Exception as e:
        print(f"An error occurred: {e}")


# Entry point
def main():
    """Load the slice and mask stacks and display them in 3-D.

    Any exception is caught and reported on standard output rather than
    being propagated.
    """
    try:
        plot_3d(*read_images(slice_image_files, mask_image_files))
    except Exception as err:
        print(f"An error occurred in main function: {err}")


if __name__ == '__main__':
    main()