Python + SimpleITK: 计算医学图像dice (computing the Dice score of medical image segmentations)

import os
import numpy as np
import codecs
import SimpleITK as sitk
import pandas as pd
import torch

'''

Expected contents of dice.txt (one item per line):

---number of cases listed below
---path of the ground-truth mask of case 1
---path of the predicted segmentation mask of case 1
---path of the ground-truth mask of case 2
---path of the predicted segmentation mask of case 2
    ...
    ...
    ...
'''
def readlines(file):
    """
    Read a UTF-8 text file and return its lines without line terminators.

    :param file: path of a text file
    :return: a list of line strings, one per line, with the trailing newline removed
    """
    # `open` (unlike `codecs.open`, which always works in binary mode) applies
    # universal-newline translation, so Windows '\r\n' endings do not leave a
    # stray '\r' on each line; the `with` block closes the file even on errors.
    with open(file, 'r', encoding='utf-8') as fp:
        return [line.rstrip('\n') for line in fp]

def read_test_txt(imlist_file):
    '''
    Parse a case-list text file into two parallel lists of paths.

    Layout: the first line holds the number of cases N, followed by 2*N lines
    that alternate between the first and the second path of each case.

    :param imlist_file: image list file path
    :return: (im_list, seg_list) -- the first and the second path of every case
    :raises ValueError: if the file is empty, the count is not an integer, or
                        there are fewer than 2*N path lines
    :raises FileNotFoundError: if a listed path does not exist
    '''
    lines = readlines(imlist_file)
    if not lines:
        raise ValueError('imlist file is empty: {}'.format(imlist_file))
    num_cases = int(lines[0])

    if (len(lines) - 1) < (num_cases * 2):
        raise ValueError('too few lines in imlist file')
    im_list, seg_list = [], []
    for i in range(num_cases):
        im_path = lines[1 + i * 2].strip()
        seg_path = lines[2 + i * 2].strip()
        # Explicit raises instead of `assert`: asserts are stripped under `python -O`.
        if not os.path.exists(im_path):
            raise FileNotFoundError('image not exist: {}'.format(im_path))
        if not os.path.exists(seg_path):
            raise FileNotFoundError('mask not exist: {}'.format(seg_path))
        im_list.append(im_path)
        seg_list.append(seg_path)

    return im_list, seg_list

def cal_dice(input_tensor, target, num_class, epsilon=1e-6):
    '''
    Compute the Dice score of every foreground class (labels 1 .. num_class-1).

    For class i:  dice_i = 2*|P_i & T_i| / (|P_i| + |T_i| + 2*epsilon),
    where P_i / T_i are the elements equal to i in input_tensor / target.
    Label 0 is skipped. A class absent from both tensors scores 0.
    Each per-class score is also printed.

    :params input_tensor:   the result of segmentation (label tensor)
    :params target:         ground truth mask, same number of elements as input_tensor
    :params num_class:      number of labels, background (0) included
    :params epsilon:        small constant that keeps the denominator non-zero
    :return:                list of 0-dim float tensors, one per foreground class
    '''
    dice_score = []
    for i in range(1, num_class):
        # reshape (not view): element-wise results of a permuted/transposed input
        # keep its non-contiguous strides, and .view(-1) raises on those.
        input_i = ((input_tensor == i) * 1).reshape(-1)
        target_i = ((target == i) * 1).reshape(-1)
        # compute dice score
        intersect = torch.sum(input_i * target_i, 0)
        input_area = torch.sum(input_i, 0)
        target_area = torch.sum(target_i, 0)
        sum_area = input_area + target_area + 2 * epsilon

        dice_score_i = 2 * intersect.float() / sum_area.float()
        dice_score.append(dice_score_i)
        print('class = {}, dice = {}'.format(i, dice_score_i))

    return dice_score

def val(input_path, results_csv, case_name_index=5):
    '''
    Compute Dice scores for every case listed in a txt file and save them to CSV.

    Each case is a (ground truth, prediction) pair of image files read with
    SimpleITK. The number of classes is the number of distinct labels in the
    ground truth, background included:
    2 labels -> column 'tumor'; 3 labels -> columns 'left_testis', 'right_testis'.
    Cases with any other label count are reported and skipped.

    :param input_path:      case-list .txt file (count line, then gt/prediction path pairs)
    :param results_csv:     output CSV path, one row per case
    :param case_name_index: position of the case name in pre_path.split('/');
                            depends on the directory layout (default 5)
    :raises ValueError:     if input_path is not a txt file
    '''
    if not input_path.endswith('txt'):
        raise ValueError('image test_list must be a txt file')
    gt_list, pre_list = read_test_txt(input_path)

    rows = []
    for gt_path, pre_path in zip(gt_list, pre_list):
        print('{}: {}'.format(gt_path, pre_path))

        gt_mask_np = sitk.GetArrayFromImage(sitk.ReadImage(gt_path)).astype(float)
        pre_mask_np = sitk.GetArrayFromImage(sitk.ReadImage(pre_path)).astype(float)
        case_name = pre_path.split('/')[case_name_index]
        print(case_name)
        # NOTE(review): counting distinct labels assumes labels are 0..K-1 and all
        # present in the ground truth -- confirm this holds for your data.
        num_class = len(np.unique(gt_mask_np))

        # leading batch dimension added; cal_dice flattens the tensors anyway
        gt_mask = torch.unsqueeze(torch.from_numpy(gt_mask_np), 0).float()
        pre_mask = torch.unsqueeze(torch.from_numpy(pre_mask_np), 0).float()
        dice_score = cal_dice(pre_mask, gt_mask, num_class)

        if num_class == 3:
            rows.append({
                'case_name': case_name,
                'left_testis': dice_score[0].item(),
                'right_testis': dice_score[1].item(),
            })
        elif num_class == 2:
            rows.append({
                'case_name': case_name,
                'tumor': dice_score[0].item(),
            })
        else:
            # No column layout for this label count: skip the case explicitly
            # rather than reusing a stale row from a previous case.
            print('skip {}: unsupported number of labels ({})'.format(case_name, num_class))

    # 'case_name'/'tumor' always come first; the two-label columns follow only
    # when some case produced them (the layout the old append-based code wrote).
    columns = ['case_name', 'tumor']
    if any('left_testis' in row for row in rows):
        columns += ['left_testis', 'right_testis']
    # Build the frame once after the loop (DataFrame.append was removed in pandas 2.0).
    pd.DataFrame(rows, columns=columns).to_csv(results_csv, index=False)

if __name__ == '__main__':
    # Site-specific paths; adjust before running. Guarded so that importing this
    # module (e.g. to reuse cal_dice) does not start an evaluation run.
    input_path = '/home/xxx/06_datalist/dice.txt'
    results_csv = '/home/xxx/06_datalist/dice.csv'
    val(input_path, results_csv)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值