swin_unet代碼

import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk
from PIL import Image
import copy
 
# 定義 Dice Loss 的類別
class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes
 
    # 將目標張量轉換為 one-hot 編碼
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # 生成每個類別的布林掩膜
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
 
    # 計算單一類別的 Dice Loss
    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5  # 平滑因子,避免分母為零
        intersect = torch.sum(score * target)  # 預測與目標的交集
        y_sum = torch.sum(target * target)  # 目標的平方和
        z_sum = torch.sum(score * score)  # 預測的平方和
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
 
    # 前向傳播計算損失
    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)  # 對輸入進行 softmax 處理
        target = self._one_hot_encoder(target)  # 將目標轉換為 one-hot 編碼
        if weight is None:
            weight = [1] * self.n_classes  # 如果沒有權重,則對每個類別賦予相同權重
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes  # 返回平均損失
 
# 計算每個案例的評估指標(Dice 和 HD95)
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1  # 將預測二值化
    gt[gt > 0] = 1  # 將標籤二值化
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)  # 計算 Dice Coefficient
        hd95 = metric.binary.hd95(pred, gt)  # 計算 95% Hausdorff Distance
        return dice, hd95
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0  # 如果預測有值但標籤為空
    else:
        return 0, 0  # 如果預測和標籤都為空
 
# 測試單一體積影像
def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    _, x, y = image.shape
 
    # 將影像縮放到網路輸入大小 (224x224)
    if x != patch_size[0] or y != patch_size[1]:
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)
    input = torch.from_numpy(image).unsqueeze(0).float().cuda()
    net.eval()
    with torch.no_grad():
        out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
        out = out.cpu().detach().numpy()
        # 將預測結果縮放回原始影像大小
        if x != patch_size[0] or y != patch_size[1]:
            prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
        else:
            prediction = out
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
 
    # 將不同類別區域以彩色顯示
    if test_save_path is not None:
        a1 = copy.deepcopy(prediction)
        a2 = copy.deepcopy(prediction)
        a3 = copy.deepcopy(prediction)
        # r 通道
        a1[a1 == 1] = 0
        # g 通道
        a2[a2 == 1] = 255
        # b 通道
        a3[a3 == 1] = 0
        a1 = Image.fromarray(np.uint8(a1)).convert('L')
        a2 = Image.fromarray(np.uint8(a2)).convert('L')
        a3 = Image.fromarray(np.uint8(a3)).convert('L')
        prediction = Image.merge('RGB', [a1, a2, a3])
        prediction.save(test_save_path + '/' + case + '.png')
 
    return metric_list
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset

# 隨機旋轉和翻轉影像及標籤
def random_rot_flip(image, label):
    k = np.random.randint(0, 4)  # 隨機旋轉 0、90、180 或 270 度
    image = np.rot90(image, k)
    label = np.rot90(label, k)
    axis = np.random.randint(0, 2)  # 隨機選擇翻轉軸(0 為垂直翻轉,1 為水平翻轉)
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label

# 隨機旋轉影像及標籤
def random_rotate(image, label):
    angle = np.random.randint(-20, 20)  # 隨機選擇旋轉角度(-20 到 20 度)
    image = ndimage.rotate(image, angle, order=0, reshape=False)  # 影像旋轉
    label = ndimage.rotate(label, angle, order=0, reshape=False)  # 標籤旋轉
    return image, label

# 定義隨機生成器類別
class RandomGenerator(object):
    def __init__(self, output_size):
        self.output_size = output_size  # 設定輸出大小

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # 隨機執行旋轉或翻轉
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)

        x, y = image.shape  # 獲取影像的寬和高
        # 如果影像大小與輸出大小不同,進行縮放
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3)
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        
        # 將影像轉為張量並添加一個維度
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))

        sample = {'image': image, 'label': label.long()}  # 將影像和標籤打包
        return sample

# 定義 Synapse 資料集類別
class Synapse_dataset(Dataset):
    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform  # 是否應用數據增強
        self.split = split  # 資料集分割類型(訓練或測試)
        self.sample_list = open(os.path.join(list_dir, self.split + '.txt')).readlines()  # 讀取對應的樣本列表
        self.data_dir = base_dir  # 資料集的根目錄

    def __len__(self):
        return len(self.sample_list)  # 返回資料集的樣本數

    def __getitem__(self, idx):
        if self.split == "train":  # 如果是訓練集
            slice_name = self.sample_list[idx].strip('\n')  # 獲取樣本名稱
            data_path = os.path.join(self.data_dir, slice_name + '.npz')  # 構建資料路徑
            data = np.load(data_path)  # 加載資料
            image, label = data['image'], data['label']
            sample = {'image': image, 'label': label}  # 將影像和標籤打包
            if self.transform:
                sample = self.transform(sample)  # 應用數據增強
            sample['case_name'] = self.sample_list[idx].strip('\n')  # 添加樣本名稱
            return sample            
        else:  # 如果是測試集
            slice_name = self.sample_list[idx].strip('\n')  # 獲取樣本名稱
            data_path = os.path.join(self.data_dir, slice_name + '.npz')  # 構建資料路徑
            try:
                data = np.load(data_path)  # 加載資料
                image, label = data['image'], data['label']
            except (FileNotFoundError, KeyError) as e:
                print(f"Error loading {data_path}: {e}")
                return None  # 如果發生錯誤,返回 None
            # 測試時將影像轉為張量並添加一個維度
            image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
            label = torch.from_numpy(label.astype(np.float32))
            return {"image": image, "label": label, "case_name": slice_name}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值