dyn-unet对训练好的模型进行推理输出结果

这是写出来对训练好的模型进行前向推理的代码,最终将预测结果输出保存在病例文件夹内。

import os
import torch
import SimpleITK as sitk
import numpy as np
from monai.networks.nets import UNet, DynUNet
from monai.transforms import (
    AsDiscrete,
    AddChannel,
    Compose,
    ToTensor,
    SqueezeDim,
    Activations,
    Resize,
    Spacing,
    ScaleIntensity,
    KeepLargestConnectedComponent,
    NormalizeIntensity,
    Orientation,
)


root_dir = r"D:\test"  # 包含所有病例的目录
model_dir = r"C:\Users\fffyyy\Desktop\WG_segmentation"  # 模型所在路径
model_name = r"dyn_model_epoch200.pth"   # 模型名称

save_name = "pred.nii.gz"  # 分割结果保存的文件名

device = torch.device("cuda:0")


in_channels = 1
n_class = 2
spatial_size = [128, 128, 32]
spacing = [0.62, 0.62, 3.2]
module_dir = model_dir
kernels = []
strides = []
sizes, spacings = spatial_size, spacing

while True:
    spacing_ratio = [sp / min(spacings) for sp in spacings]
    stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
    kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
    if all(s == 1 for s in stride):
        break
    sizes = [i / j for i, j in zip(sizes, stride)]
    spacings = [i * j for i, j in zip(spacings, stride)]
    kernels.append(kernel)
    strides.append(stride)
strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])

model = DynUNet(
    spatial_dims=3,
    in_channels=in_channels,
    out_channels=n_class,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    norm_name="instance",
    deep_supervision=True,
    deep_supr_num=2,
    res_block=False,
).to(device)
model.load_state_dict(torch.load(os.path.join(module_dir, model_name)))
model.eval()


# 加载数据的变换
keys = ["image", "label"]
load_val_transforms = Compose(
    [
        # LoadImage(),
        AddChannel(),
        # Orientation(axcodes="RAS"),
        # CropForeground(select_fn=lambda x: x >= 1, margin=0),
        Spacing(
            pixdim=[0.62, 0.62, 3.2],
            mode=("bilinear"),
        ),
        Resize(spatial_size=(128, 128, 32), mode='trilinear'),
        ScaleIntensity(),
        # NormalizeIntensity(nonzero=False, channel_wise=True),
        ToTensor(),
    ]
)
# # 数据变换实例化
addchannel = AddChannel()
squeeze = SqueezeDim()
spacing = Spacing(pixdim=[0.62, 0.62, 3.2], mode=("bilinear"))
resize = Resize(spatial_size=(128, 128, 32), mode='trilinear')
orientation = Orientation(axcodes="RAS")
scale_intensity = ScaleIntensity()
normalizeIntensity = NormalizeIntensity(nonzero=False, channel_wise=True)
keep_largest_connected_component = KeepLargestConnectedComponent(applied_labels=1)
to_tensor = ToTensor()

post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)  # 输出图像后处理实例化
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True, n_classes=2), KeepLargestConnectedComponent(applied_labels=1)])
with torch.no_grad():
    for root, sub_dir, files in os.walk(root_dir):

        if not sub_dir:
            # print(root)
            image_t2 = sitk.ReadImage(os.path.join(root, "T2WI.nii.gz"))
            # print(image_t2.GetSize())
            image_t2_arr = sitk.GetArrayFromImage(image_t2).T
            # 数据预处理

            image_t2_arr = addchannel(image_t2_arr)
            image_t2_arr = image_t2_arr.astype(np.float)
            image_t2_arr = resize(image_t2_arr)
            image_t2_arr = scale_intensity(image_t2_arr)
            trans_image = to_tensor(image_t2_arr).to(device)
            trans_image = addchannel(trans_image)
            
            val_outputs = model(trans_image)
            val_outputs = val_outputs.cpu()

            val_outputs = squeeze(val_outputs)
            resize2 = Resize(spatial_size=image_t2.GetSize(), mode="trilinear", align_corners=True)
            val_outputs = resize2(val_outputs)
            val_outputs = addchannel(val_outputs)
            val_outputs = post_trans(torch.from_numpy(val_outputs))
            val_outputs = squeeze(val_outputs)

            # val_outputs = torch.from_numpy(val_outputs)
            show_out = torch.argmax(val_outputs, dim=0).detach().cpu()  # [0, :, :, 8]
            output_array = show_out.numpy()
            output_array = output_array.T
            output_array = output_array.astype(np.float32)
            output_save = sitk.GetImageFromArray(output_array)
            # print(output_save.GetSize())
            output_save.CopyInformation(image_t2)
            sitk.WriteImage(output_save, os.path.join(root, save_name))
            patient_name = root.split("\\")[-1]
            print(f"{patient_name} done")









  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值