直接调库还是YYDS
import SimpleITK as sitk
import numpy as np
image1_path = './data/100408_seg/mask_100408_arssr.nii.gz'
image2_path = './data/origin_seg/mask_100408.nii.gz'
image1_itk = sitk.ReadImage(image1_path, sitk.sitkUInt8)
image2_itk = sitk.ReadImage(image2_path, sitk.sitkUInt8)
image1 = sitk.GetArrayFromImage(image1_itk)
image2 = sitk.GetArrayFromImage(image2_itk)
# 二值化区域,原值为2的设为1,其它为0
image1 = np.where(image1 == 2, 1, 0)
image2 = np.where(image2 == 2, 1, 0)
image1 = sitk.GetImageFromArray(image1)
image2 = sitk.GetImageFromArray(image2)
overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
overlap_measures_filter.Execute(image1, image2)
# Dice
dice_similarity = overlap_measures_filter.GetDiceCoefficient()
print(dice_similarity)
再给一种方案
import numpy as np
import nibabel as nib
def cal_subject_level_dice(prediction, target, class_num=7):# class_num是你分割的目标的类别个数
'''
step1: calculate the dice of each category
step2: remove the dice of the empty category and background, and then calculate the mean of the remaining dices.
:param prediction: the automated segmentation result, a numpy array with shape of (h, w, d)
:param target: the ground truth mask, a numpy array with shape of (h, w, d)
:param class_num: total number of categories
:return:
'''
eps = 1e-10
empty_value = -1.0
dscs = empty_value * np.ones((class_num), dtype=np.float32)
for i in range(0, class_num):
if i not in target and i not in prediction:
continue
target_per_class = np.where(target == i, 1, 0).astype(np.float32)
prediction_per_class = np.where(prediction == i, 1, 0).astype(np.float32)
tp = np.sum(prediction_per_class * target_per_class)
fp = np.sum(prediction_per_class) - tp
fn = np.sum(target_per_class) - tp
dsc = 2 * tp / (2 * tp + fp + fn + eps)
dscs[i] = dsc
dscs = np.where(dscs == -1.0, np.nan, dscs)
print("dscs:", dscs)
# np.nanmean(dscs[1:])
subject_level_dice = dscs[1:]
return subject_level_dice
def evaluate_demo(prediction_nii_files, target_nii_files):
'''
This is a demo for calculating the mean dice of all subjects.
:param prediction_nii_files: a list which contains the .nii file paths of predicted segmentation
:param target_nii_files: a list which contains the .nii file paths of ground truth mask
:return:
'''
dscs = []
for prediction_nii_file, target_nii_file in zip(prediction_nii_files, target_nii_files):
prediction_nii = nib.load(prediction_nii_file)
prediction = prediction_nii.get_fdata()
target_nii = nib.load(target_nii_file)
target = target_nii.get_fdata()
dsc = cal_subject_level_dice(prediction, target, class_num=7)
dscs.append(dsc)
# np.mean(dscs)
return dscs
dice = evaluate_demo(['./data/100408_seg/mask_100408_arssr.nii.gz'],['./data/origin_seg/mask_100408.nii.gz'])#前一个地址是你预测的三维图像的地址,后一个是标签地址
print(dice)