FastSurfer 官方发布的 Tutorial 代码基于 Colab，本文介绍如何在本地服务器上运行 FastSurfer 进行快速分割。
第一步：运行 run_prediction.py
E:\Anaconda\envs\fastsurfer\python.exe D:\FastSurfer-dev\FastSurferCNN\run_prediction.py --t1 "'D:/FastSurfer-dev/Tutorial/140_orig.mgz'"
One of the following three options has to be passed --in_dir, --csv_file or --t1 with an absolute file path. Please specify the data input directory, the subject list file or the full path to input volume
[INFO: run_prediction.py: 652]: Checking or downloading default checkpoints ...
Process finished with exit code 1
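上面的命令会报错：--t1 的路径外多套了一层单引号（"'...'"），很可能因此没有被识别为有效的绝对路径。解决办法是在 PyCharm 的 Run → Edit Configurations... 中设置运行参数（Parameters），路径不要额外加引号，并用 --sd 指定输出目录：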
--t1 D:/FastSurfer-dev/dataset/140_orig.mgz --sd D:/FastSurfer-dev/output/
运行结束后，会在 output 文件夹下生成一个以输入文件名命名的子文件夹 140_orig.mgz，分割结果位于：
D:\FastSurfer-dev\output\140_orig.mgz\mri\aparc.DKTatlas+aseg.deep.mgz
第二步：新建一个 Python 脚本，将 mgz 转换为 NIfTI（.nii.gz），并运行 Tutorial 中的 plot_predictions 函数可视化分割结果。
import os
import sys
from os.path import exists, join, basename, splitext

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from skimage import color
from torchvision import utils

plt.style.use('seaborn-v0_8-whitegrid')

# Input T1 volume that was passed to run_prediction.py via --t1.
img = 'D:/FastSurfer-dev/dataset/vs_gk_10vs_gk_t1_refT1_x2_SR.mgz'
# Output root that was passed to run_prediction.py via --sd.
output_root = 'D:/FastSurfer-dev/output'
# The subject folder is named after the input file, extension included
# (e.g. output/140_orig.mgz/mri/...), so derive it once instead of
# hard-coding the same path in several places.
subject_dir = join(output_root, basename(img))
output = join(subject_dir, 'mri', 'aparc.DKTatlas+aseg.deep.mgz')


def _nifti_path(mgz_path):
    """Return *mgz_path* with its final extension replaced by ``.nii.gz``."""
    return splitext(mgz_path)[0] + '.nii.gz'


def mgz_to_nifti(src, dst):
    """Convert the FreeSurfer ``.mgz`` volume at *src* to a NIfTI file at *dst*.

    The voxel array is taken from ``dataobj`` (on-disk dtype) rather than
    ``get_fdata()`` (always float64), so integer label maps stay integer.

    :param src: path of the ``.mgz`` file to read.
    :param dst: path of the NIfTI file to write; should end in ``.nii`` or
        ``.nii.gz`` and must differ from *src*.
    """
    mgz = nib.load(src)
    nifti = nib.Nifti1Image(np.asanyarray(mgz.dataobj), mgz.affine)
    nib.save(nifti, dst)


# Segmentation produced by run_prediction.py.
mgz_to_nifti(output, _nifti_path(output))
# Input T1: write a separate .nii.gz next to it, never over the .mgz input.
mgz_to_nifti(img, _nifti_path(img))
def plot_predictions(image, pred, *, offset=16, num_slices=16, nrow=4):
    """Plot a grid of slices, the predicted labels, and their overlay.

    Slices ``n // 2 + offset`` up to (not including)
    ``n // 2 + offset + num_slices`` are taken along the first axis of both
    volumes, where ``n = image.shape[0]``. The defaults reproduce the
    FastSurfer tutorial layout (16 slices starting 16 past the middle,
    4 per row).

    :param image: 3-D intensity volume ``(n, h, w)``. It is passed to
        ``imshow`` unscaled, so values are expected in [0, 1].
    :param pred: 3-D label volume, expected to have the same shape as
        *image*; label 0 is treated as background.
    :param offset: first slice to show, relative to the middle slice.
    :param num_slices: number of consecutive slices to show.
    :param nrow: number of slices per grid row.
    :return: None; the figure is displayed with ``plt.show()``.
    """
    fig = plt.figure(figsize=(20, 20))
    n, _, _ = image.shape
    start = n // 2 + offset
    stop = start + num_slices

    # make_grid takes a (B, C, H, W) batch: add a singleton channel axis.
    image_grid = utils.make_grid(
        torch.from_numpy(np.expand_dims(image[start:stop, :, :], 1)), nrow=nrow
    )
    # (C, H, W) -> (H, W, C) for imshow.
    image_rgb = image_grid.numpy().transpose((1, 2, 0))

    # Channel 0 of the label grid is the 2-D label map fed to label2rgb.
    label_grid = utils.make_grid(
        torch.from_numpy(np.expand_dims(pred[start:stop, ...], 1)), nrow=nrow
    )[0]
    color_grid = color.label2rgb(label_grid.numpy(), bg_label=0)

    plt.subplot(311)
    plt.imshow(image_rgb)
    plt.title('Slices')

    plt.subplot(312)
    plt.imshow(color_grid)
    plt.title('Prediction')

    plt.subplot(313)
    plt.imshow(image_rgb)
    plt.imshow(color_grid, alpha=0.3)
    # Prediction over the input slices; no ground truth is involved here.
    plt.title('Overlay')

    for ax in fig.axes:
        # Explicitly off: grid(visible=None) would only toggle the style's grid.
        ax.grid(False)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
# Scale intensities into [0, 1] for imshow; dividing by 255 assumes an 8-bit
# (0-255) T1 volume — TODO confirm for inputs that are not 8-bit.
orig_data = nib.load(img).get_fdata() / 255
# Reuse the `output` path defined with the conversion step so the plot reads
# the same segmentation file that was converted, instead of a second
# hard-coded literal that can drift out of sync.
pred_data = nib.load(output).get_fdata()
plot_predictions(orig_data, pred_data)
运行脚本，查看分割结果的可视化图像。