import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk
from PIL import Image
import copy
# Dice Loss for multi-class segmentation
class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    # Convert an integer label map into a one-hot encoded tensor
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # boolean mask for each class
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
    # Dice loss for a single class
    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5  # smoothing term to avoid division by zero
        intersect = torch.sum(score * target)  # overlap between prediction and target
        y_sum = torch.sum(target * target)     # sum of squared target values
        z_sum = torch.sum(score * score)       # sum of squared prediction values
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
    # Forward pass: average Dice loss over all classes
    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)  # turn logits into probabilities
        target = self._one_hot_encoder(target)  # one-hot encode the label map
        if weight is None:
            weight = [1] * self.n_classes  # equal weight for every class by default
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes  # mean loss over all classes
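
# The sketch below is purely illustrative and not part of the original training
# code: it shows how DiceLoss is typically combined with CrossEntropyLoss on raw
# logits. The tensor shapes, the class count and the 0.5/0.5 weighting are
# assumptions, not values taken from this repository.
def _demo_dice_loss():
    num_classes = 9
    logits = torch.randn(2, num_classes, 224, 224)           # raw network outputs
    labels = torch.randint(0, num_classes, (2, 224, 224))    # integer label map
    ce_loss = nn.CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    loss_ce = ce_loss(logits, labels)
    loss_dice = dice_loss(logits, labels, softmax=True)      # softmax is applied inside forward()
    return 0.5 * loss_ce + 0.5 * loss_dice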
# Per-case evaluation metrics (Dice and 95% Hausdorff distance)
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1  # binarize the prediction
    gt[gt > 0] = 1      # binarize the ground truth
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)    # Dice coefficient
        hd95 = metric.binary.hd95(pred, gt)  # 95th-percentile Hausdorff distance
        return dice, hd95
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0  # prediction is non-empty but the ground truth is empty
    else:
        return 0, 0  # prediction is empty
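
# Illustrative only: a tiny synthetic example of calculate_metric_percase. The
# 8x8 masks are made up; in the real pipeline the inputs come from
# test_single_volume below, and medpy must be installed for metric.binary.*.
def _demo_metric_percase():
    pred = np.zeros((8, 8), dtype=np.uint8)
    gt = np.zeros((8, 8), dtype=np.uint8)
    pred[2:6, 2:6] = 1  # predicted foreground square
    gt[3:7, 3:7] = 1    # ground-truth square, partially overlapping the prediction
    return calculate_metric_percase(pred, gt)  # (dice, hd95)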
# Evaluate a single volume
def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    _, x, y = image.shape
    # Resize the input to the network's expected patch size
    if x != patch_size[0] or y != patch_size[1]:
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)
    input = torch.from_numpy(image).unsqueeze(0).float().cuda()
    net.eval()
    with torch.no_grad():
        out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
        out = out.cpu().detach().numpy()
    # Resize the prediction back to the original image size
    if x != patch_size[0] or y != patch_size[1]:
        prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
    else:
        prediction = out
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    # Render the predicted regions as an RGB image (class 1 is drawn in green)
    if test_save_path is not None:
        a1 = copy.deepcopy(prediction)
        a2 = copy.deepcopy(prediction)
        a3 = copy.deepcopy(prediction)
        # R channel
        a1[a1 == 1] = 0
        # G channel
        a2[a2 == 1] = 255
        # B channel
        a3[a3 == 1] = 0
        a1 = Image.fromarray(np.uint8(a1)).convert('L')
        a2 = Image.fromarray(np.uint8(a2)).convert('L')
        a3 = Image.fromarray(np.uint8(a3)).convert('L')
        prediction = Image.merge('RGB', [a1, a2, a3])
        prediction.save(test_save_path + '/' + case + '.png')
    return metric_list
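
# Illustrative only: a typical evaluation loop around test_single_volume. The
# 'testloader' and the averaging scheme are assumptions about the surrounding
# test script, not code from this file; batch_size is expected to be 1 so that
# each batch holds exactly one case.
def _demo_inference(model, testloader, num_classes, save_path=None):
    model.eval()
    metric_sum = 0.0
    for sampled_batch in testloader:
        image, label = sampled_batch['image'], sampled_batch['label']
        case_name = sampled_batch['case_name'][0]
        metric_i = test_single_volume(image, label, model, classes=num_classes,
                                      patch_size=[256, 256],
                                      test_save_path=save_path, case=case_name)
        metric_sum += np.array(metric_i)  # shape (num_classes - 1, 2): [dice, hd95] per class
    return metric_sum / len(testloader)   # mean Dice and HD95 per foreground class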
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage import zoom
from torch.utils.data import Dataset
# Randomly rotate (by a multiple of 90 degrees) and flip the image and label
def random_rot_flip(image, label):
    k = np.random.randint(0, 4)  # rotate by 0, 90, 180 or 270 degrees
    image = np.rot90(image, k)
    label = np.rot90(label, k)
    axis = np.random.randint(0, 2)  # flip axis (0 = vertical, 1 = horizontal)
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label
# Randomly rotate the image and label by a small angle
def random_rotate(image, label):
    angle = np.random.randint(-20, 20)  # random angle in [-20, 20) degrees
    image = ndimage.rotate(image, angle, order=0, reshape=False)  # rotate the image
    label = ndimage.rotate(label, angle, order=0, reshape=False)  # rotate the label
    return image, label
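
# Illustrative only: applying the two augmentations to a dummy slice to check
# that shapes and label values are preserved. The 512x512 size and 9 classes
# are arbitrary example values.
def _demo_augmentations():
    image = np.random.rand(512, 512).astype(np.float32)
    label = np.random.randint(0, 9, (512, 512))
    image, label = random_rot_flip(image, label)  # 90-degree rotation + flip
    image, label = random_rotate(image, label)    # small-angle rotation
    return image.shape, np.unique(label)          # labels stay integer-valued (order=0)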
# Random augmentation transform that also resizes samples to the target output size
class RandomGenerator(object):
    def __init__(self, output_size):
        self.output_size = output_size  # target (height, width)

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # With prob. 0.5 apply a 90-degree rotation/flip; otherwise, with prob. 0.5, a small rotation
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y = image.shape  # image height and width
        # Rescale to the desired output size if necessary
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3)
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        # Convert to tensors; add a channel dimension to the image
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}  # pack image and label
        return sample
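
# Illustrative only: wiring RandomGenerator into torchvision's Compose, which is
# how callable transforms like this one are usually chained. torchvision is an
# assumed extra dependency, and the 224x224 output size is an example value.
def _demo_random_generator():
    from torchvision import transforms
    transform = transforms.Compose([RandomGenerator(output_size=[224, 224])])
    sample = {'image': np.random.rand(512, 512).astype(np.float32),
              'label': np.random.randint(0, 9, (512, 512))}
    out = transform(sample)
    return out['image'].shape, out['label'].shape  # (1, 224, 224) and (224, 224)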
# Synapse dataset
class Synapse_dataset(Dataset):
    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform  # optional data augmentation
        self.split = split          # dataset split ("train" or "test")
        # Read the list of sample names for this split
        with open(os.path.join(list_dir, self.split + '.txt')) as f:
            self.sample_list = f.readlines()
        self.data_dir = base_dir    # root directory of the dataset

    def __len__(self):
        return len(self.sample_list)  # number of samples in the split

    def __getitem__(self, idx):
        if self.split == "train":  # training split
            slice_name = self.sample_list[idx].strip('\n')  # sample name
            data_path = os.path.join(self.data_dir, slice_name + '.npz')  # path to the .npz file
            data = np.load(data_path)  # load the data
            image, label = data['image'], data['label']
            sample = {'image': image, 'label': label}  # pack image and label
            if self.transform:
                sample = self.transform(sample)  # apply data augmentation
            sample['case_name'] = self.sample_list[idx].strip('\n')  # attach the sample name
            return sample
        else:  # test split
            slice_name = self.sample_list[idx].strip('\n')  # sample name
            data_path = os.path.join(self.data_dir, slice_name + '.npz')  # path to the .npz file
            try:
                data = np.load(data_path)  # load the data
                image, label = data['image'], data['label']
            except (FileNotFoundError, KeyError) as e:
                print(f"Error loading {data_path}: {e}")
                return None  # on failure, return None (callers must filter these out)
            # At test time, convert to tensors and add a channel dimension to the image
            image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
            label = torch.from_numpy(label.astype(np.float32))
            return {"image": image, "label": label, "case_name": slice_name}