"""Run forward inference with a trained segmentation model and save the
predicted mask (pred.nii.gz) inside each patient case folder."""
import os
import torch
import SimpleITK as sitk
import numpy as np
from monai.networks.nets import UNet, DynUNet
from monai.transforms import (
AsDiscrete,
AddChannel,
Compose,
ToTensor,
SqueezeDim,
Activations,
Resize,
Spacing,
ScaleIntensity,
KeepLargestConnectedComponent,
NormalizeIntensity,
Orientation,
)
# --- Configuration -----------------------------------------------------------
root_dir = r"D:\test"  # directory containing one sub-folder per patient case
model_dir = r"C:\Users\fffyyy\Desktop\WG_segmentation"  # directory holding the trained model
model_name = r"dyn_model_epoch200.pth"  # checkpoint file name
save_name = "pred.nii.gz"  # file name under which the segmentation is saved
device = torch.device("cuda:0")  # NOTE(review): assumes a CUDA GPU is present — confirm
in_channels = 1  # single-channel (grayscale) input
n_class = 2  # two output classes (background + foreground)
spatial_size = [128, 128, 32]
spacing = [0.62, 0.62, 3.2]
module_dir = model_dir
kernels = []
strides = []
sizes, spacings = spatial_size, spacing
while True:
spacing_ratio = [sp / min(spacings) for sp in spacings]
stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
if all(s == 1 for s in stride):
break
sizes = [i / j for i, j in zip(sizes, stride)]
spacings = [i * j for i, j in zip(spacings, stride)]
kernels.append(kernel)
strides.append(stride)
strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])
# Build the DynUNet with the kernel/stride schedule computed above.
model = DynUNet(
    spatial_dims=3,
    in_channels=in_channels,
    out_channels=n_class,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],  # one upsampling step per downsampling level
    norm_name="instance",
    deep_supervision=True,
    deep_supr_num=2,
    res_block=False,
).to(device)
# map_location makes the checkpoint load regardless of the device it was
# saved from (e.g. a GPU index that does not exist on this machine).
state_dict = torch.load(os.path.join(module_dir, model_name), map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout, freezes norm statistics
# Data-loading transforms (NOTE(review): `keys` and `load_val_transforms` are
# defined but never used by the inference loop below — confirm before removing)
keys = ["image", "label"]
load_val_transforms = Compose(
    [
        # LoadImage(),
        AddChannel(),
        # Orientation(axcodes="RAS"),
        # CropForeground(select_fn=lambda x: x >= 1, margin=0),
        Spacing(
            pixdim=[0.62, 0.62, 3.2],
            mode=("bilinear"),
        ),
        Resize(spatial_size=(128, 128, 32), mode='trilinear'),
        ScaleIntensity(),
        # NormalizeIntensity(nonzero=False, channel_wise=True),
        ToTensor(),
    ]
)
# Stand-alone transform instances used by the inference loop
addchannel = AddChannel()  # prepends one dimension (used for channel AND batch dims)
squeeze = SqueezeDim()  # removes the leading dimension
# NOTE: rebinds the name `spacing` (previously the spacing list) to a transform;
# this instance is not applied in the loop below — presumably resampling was
# dropped in favor of Resize; verify against the training pipeline.
spacing = Spacing(pixdim=[0.62, 0.62, 3.2], mode=("bilinear"))
resize = Resize(spatial_size=(128, 128, 32), mode='trilinear')  # to the network input size
orientation = Orientation(axcodes="RAS")  # NOTE(review): unused below
scale_intensity = ScaleIntensity()  # rescale intensities to [0, 1]
normalizeIntensity = NormalizeIntensity(nonzero=False, channel_wise=True)  # NOTE(review): unused below
keep_largest_connected_component = KeepLargestConnectedComponent(applied_labels=1)
to_tensor = ToTensor()
# Output post-processing (NOTE(review): `post_pred` is unused below)
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True, n_classes=2), KeepLargestConnectedComponent(applied_labels=1)])
# --- Inference loop: one prediction per leaf case folder ---------------------
with torch.no_grad():
    for root, sub_dir, files in os.walk(root_dir):
        # A leaf directory (no sub-directories) is one patient case containing T2WI.nii.gz.
        if not sub_dir:
            image_t2 = sitk.ReadImage(os.path.join(root, "T2WI.nii.gz"))
            # GetArrayFromImage yields (z, y, x); transpose to (x, y, z).
            image_t2_arr = sitk.GetArrayFromImage(image_t2).T
            # Preprocess: add channel dim, cast, resize to model input, scale intensities.
            image_t2_arr = addchannel(image_t2_arr)
            # np.float was removed in NumPy 1.24; float32 also matches the model weights.
            image_t2_arr = image_t2_arr.astype(np.float32)
            image_t2_arr = resize(image_t2_arr)
            image_t2_arr = scale_intensity(image_t2_arr)
            trans_image = to_tensor(image_t2_arr).to(device)
            trans_image = addchannel(trans_image)  # add batch dim -> (1, 1, x, y, z)
            val_outputs = model(trans_image)
            val_outputs = val_outputs.cpu()
            val_outputs = squeeze(val_outputs)  # drop the batch dimension
            # Resize the prediction back onto the original image grid.
            resize2 = Resize(spatial_size=image_t2.GetSize(), mode="trilinear", align_corners=True)
            val_outputs = resize2(val_outputs)
            val_outputs = addchannel(val_outputs)
            # Sigmoid + threshold + keep the largest connected component.
            val_outputs = post_trans(torch.from_numpy(val_outputs))
            val_outputs = squeeze(val_outputs)
            # Collapse the class dimension to a label map.
            show_out = torch.argmax(val_outputs, dim=0).detach().cpu()
            # Transpose back to (z, y, x) so the saved image matches the input geometry.
            output_array = show_out.numpy().T.astype(np.float32)
            output_save = sitk.GetImageFromArray(output_array)
            output_save.CopyInformation(image_t2)  # copy origin/spacing/direction
            sitk.WriteImage(output_save, os.path.join(root, save_name))
            # basename is separator-safe, unlike splitting on a literal backslash.
            patient_name = os.path.basename(root)
            print(f"{patient_name} done")