使用U-Net在3Dircadb数据集上分割肝脏

1. 数据预处理

# preprocessing.py
import os
import numpy as np
import cv2
import pydicom
from matplotlib import pyplot as plt


def windowing(img, window_width, window_center):
    """Map a CT intensity window onto 8-bit grey levels.

    Values at or below ``window_center - window_width / 2`` become 0, values at
    or above ``window_center + window_width / 2`` become 255, and values in
    between are scaled linearly (then truncated by the uint8 cast).

    Args:
        img: array (or scalar) of intensities, e.g. HU values.
        window_width: width of the window; must be positive.
        window_center: center of the window.

    Returns:
        uint8 array (or numpy scalar) with the same shape as ``img``.

    Raises:
        ValueError: if ``window_width`` is not positive (it would otherwise
            divide by zero and produce meaningless grey levels).
    """
    width = float(window_width)
    if not width > 0:
        raise ValueError('window_width must be positive, got %r' % (window_width,))
    lower = float(window_center) - 0.5 * width
    # np.clip also works for scalars / 0-d arrays, unlike boolean-mask assignment.
    scaled = np.clip((img - lower) / width, 0, 1)
    return (scaled * 255).astype('uint8')


# Contrast Limited Adaptive Histogram Equalization (CLAHE), applied slice by slice.
def clahe_equalization(img, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply OpenCV CLAHE to every 2-D slice of a 3-D stack and rescale to [0, 1].

    Args:
        img: array of shape (n_slices, H, W). Each slice is cast to uint8
            before equalization, so values are expected to already lie in
            0..255 (e.g. the output of ``windowing``).
        clip_limit: CLAHE contrast clip limit (default 2.0).
        tile_grid_size: CLAHE tile grid as (rows, cols) (default (8, 8)).

    Returns:
        float64 array with the same shape as ``img``, values in [0, 1].

    Raises:
        ValueError: if ``img`` is not 3-D. (An explicit raise instead of
            ``assert``, which is stripped when Python runs with ``-O``.)
    """
    img = np.asarray(img)
    if img.ndim != 3:
        raise ValueError('expected a 3-D (n_slices, H, W) array, got shape %r' % (img.shape,))
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    img_equalized = np.empty(img.shape, dtype=np.float64)
    for i, slice_ in enumerate(img):
        # OpenCV needs a contiguous uint8 buffer for each slice.
        img_equalized[i] = clahe.apply(np.ascontiguousarray(slice_, dtype=np.uint8))
    return img_equalized / 255

# Raw DICOM pixel values can depend on the scanner; Hounsfield units (HU) do not.
def get_pixels_hu(scans):
    """Stack DICOM slices into one volume and convert it to Hounsfield units.

    Applies ``HU = RescaleSlope * pixel + RescaleIntercept``, using the slope and
    intercept of the first slice for the whole stack.

    Args:
        scans: sequence of pydicom datasets (already sorted in slice order); each
            needs ``pixel_array``, and the first needs ``RescaleSlope`` and
            ``RescaleIntercept``.

    Returns:
        int16 array of shape (n_slices, H, W).
    """
    volume = np.stack([ds.pixel_array for ds in scans]).astype(np.int16)
    # Per the original author's note, -2000 marks padding outside the scanned
    # region; set it to 0 before rescaling.
    volume[volume == -2000] = 0

    reference = scans[0]
    intercept, slope = reference.RescaleIntercept, reference.RescaleSlope
    # Non-unit slopes are applied in float64 and truncated back to int16.
    volume = volume if slope == 1 else (slope * volume.astype(np.float64)).astype(np.int16)

    volume += np.int16(intercept)
    return np.array(volume, dtype=np.int16)


# Output folders: preprocessed CT slices and their matching liver masks (PNG).
patient_save_path, liver_save_path = (
    './dataset/3Dircadb/patient',
    './dataset/3Dircadb/mask',
)


# Read, preprocess and save every CT slice that contains liver, plus its mask.
def _read_dicom_series(folder):
    """Read every file in ``folder`` with pydicom and return them sorted by InstanceNumber."""
    slices = [pydicom.dcmread(os.path.join(folder, file_name))
              for file_name in os.listdir(folder)]
    slices.sort(key=lambda ds: ds.InstanceNumber)
    return slices


def processimage(start, end):
    """Preprocess 3Dircadb patients ``start`` .. ``end - 1`` and save PNG slices.

    For each patient:

    * the CT series in ``PATIENT_DICOM`` is converted to HU, windowed
      (width 500, center 150) and CLAHE-equalized;
    * the liver mask series in ``MASKS_DICOM/liver`` is read in the same
      InstanceNumber order.

    Only slices whose mask is non-empty are kept. Each kept CT slice / mask
    pair is written as ``{num}_{i}.png`` into ``patient_save_path`` /
    ``liver_save_path``. Patients without a ``MASKS_DICOM/liver`` folder are
    reported and skipped.

    Args:
        start: first patient number (inclusive).
        end: patient number to stop before (exclusive), as in ``range``.

    Returns:
        ``(img_patient, img_liver)`` for the last patient that was saved, or
        ``(None, None)`` if none was.

    Raises:
        ValueError: if a patient's CT and liver-mask series differ in slice count.
    """
    # plt.imsave does not create missing directories.
    os.makedirs(patient_save_path, exist_ok=True)
    os.makedirs(liver_save_path, exist_ok=True)

    # Defined up front so the return works even if no patient is processed.
    img_patient, img_liver = None, None
    for num in range(start, end):
        print("正在处理第%d号病人" % num)
        data_path = './dataset/3Dircadb/3Dircadb1.%d/PATIENT_DICOM' % num
        # CT series -> HU -> window (width 500, center 150) -> CLAHE.
        image_array = get_pixels_hu(_read_dicom_series(data_path))
        img_ct = windowing(image_array, 500, 150)
        img_clahe = clahe_equalization(img_ct)

        # Liver mask series, sorted by InstanceNumber like the CT series.
        mask_path = './dataset/3Dircadb/3Dircadb1.%d/MASKS_DICOM' % num
        liver_path = os.path.join(mask_path, 'liver')
        if not os.path.isdir(liver_path):
            print('patient %d: no liver mask folder at %s, skipped' % (num, liver_path))
            continue
        liver_array = np.array([ds.pixel_array for ds in _read_dicom_series(liver_path)])
        if len(liver_array) != len(img_clahe):
            raise ValueError(
                'patient %d: %d CT slices but %d liver-mask slices'
                % (num, len(img_clahe), len(liver_array))
            )

        # A mask slice whose pixels sum to 0 contains no liver; keep only the others.
        has_liver = np.array([s.sum() > 0 for s in liver_array], dtype=bool)
        img_liver = liver_array[has_liver]
        img_patient = img_clahe[has_liver]

        for i, (ct_slice, liver_slice) in enumerate(zip(img_patient, img_liver)):
            plt.imsave(os.path.join(patient_save_path, f'{num}_{i}.png'), ct_slice, cmap='gray')  # CT slice
            plt.imsave(os.path.join(liver_save_path, f'{num}_{i}.png'), liver_slice, cmap='gray')  # liver mask
    return img_patient, img_liver


# processimage() uses range(start, end), so the end bound is exclusive:
# patients 1..15 need (1, 16) and patients 16..20 need (16, 21).
# NOTE: both calls write into the same output folders, so this call-level split
# does not by itself define the train/test split used later.
train_image_patient, train_image_tumor = processimage(1, 16)
test_image_patient, test_image_tumor = processimage(16, 21)

2. 构建数据集

# dataloader.py
import torch
from PIL import Image
from torchvision import transforms
from torch.utils import data
import glob
import random

# Sort both listings so that patient[i] and mask[i] name the same
# "{num}_{i}.png" slice. glob() returns files in arbitrary filesystem order,
# which is not guaranteed to be the same for the two directories.
patient = sorted(glob.glob('./dataset/3Dircadb/patient/*.png'))
mask = sorted(glob.glob('./dataset/3Dircadb/mask/*.png'))
if len(patient) != len(mask):
    raise ValueError(
        'found %d CT slices but %d masks; the two folders must match'
        % (len(patient), len(mask))
    )

# Shuffle with a fixed seed. This module is imported separately by the training
# and the testing scripts, and an unseeded shuffle would give each of them a
# different split, so test slices could have been seen during training.
# NOTE(review): the split is per slice, so slices of one patient can end up in
# both sets; splitting by patient would give a stricter evaluation.
SPLIT_SEED = 0
index = list(range(len(patient)))
random.Random(SPLIT_SEED).shuffle(index)
image = [patient[i] for i in index]
label = [mask[i] for i in index]

# The first TRAIN_SIZE shuffled slices train the model, all remaining ones test it.
TRAIN_SIZE = 500
train_images = image[:TRAIN_SIZE]
train_labels = label[:TRAIN_SIZE]
test_images = image[TRAIN_SIZE:]
test_labels = label[TRAIN_SIZE:]


# Preprocessing for both images and masks: only ToTensor (PIL image -> float
# tensor, channels first); no augmentation.
_pipeline = [transforms.ToTensor()]
transform = transforms.Compose(_pipeline)


class Portrait_dataset(data.Dataset):
    """Paired CT-slice / liver-mask PNG dataset.

    Args:
        img_paths: paths of CT slice images, loaded as 3-channel RGB.
        anno_paths: paths of mask images, aligned index-by-index with
            ``img_paths``.

    ``__getitem__`` returns ``(img_tensor, anno_tensor)``:

    * ``img_tensor`` is the module-level ``transform`` applied to the RGB image;
    * ``anno_tensor`` is a ``torch.long`` class map (1 where the grey mask is
      brighter than half intensity, else 0) with unit dimensions squeezed.

    Raises:
        ValueError: if the two path lists differ in length.
    """

    def __init__(self, img_paths, anno_paths):
        if len(img_paths) != len(anno_paths):
            raise ValueError(
                'got %d images but %d annotations' % (len(img_paths), len(anno_paths))
            )
        self.imgs = img_paths
        self.annos = anno_paths

    def __getitem__(self, index):
        img_path = self.imgs[index]
        anno_path = self.annos[index]

        # The `with` blocks close the file handles; convert() returns a loaded copy.
        with Image.open(img_path) as pil_img:
            # 3-channel input expected by the network.
            img_tensor = transform(pil_img.convert("RGB"))

        with Image.open(anno_path) as pil_anno:
            # Single-channel grey mask.
            anno_tensor = transform(pil_anno.convert("L"))

        # ToTensor maps uint8 0..255 to 0.0..1.0. Threshold at 0.5 instead of
        # casting with .long(), which would truncate every value below 1.0 to 0.
        anno_tensor = (torch.squeeze(anno_tensor) > 0.5).long()

        return img_tensor, anno_tensor

    def __len__(self):
        return len(self.imgs)


BATCH_SIZE = 2
train_set = Portrait_dataset(train_images, train_labels)
test_set = Portrait_dataset(test_images, test_labels)
# Reshuffle the training samples every epoch: the path lists are shuffled only
# once at import time. The test order stays fixed so evaluation is repeatable.
trainloader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
testloader = data.DataLoader(test_set, batch_size=BATCH_SIZE)

3. 构建U-Net模型

# model.py
from torch import nn
import torch


class downSample(nn.Module):
    """Encoder stage: optional 2x2 max-pool, then two 3x3 conv + ReLU layers.

    With padding 1 the convolutions keep the spatial size, so the output is
    (N, out_channels, H // 2, W // 2) when pooling and (N, out_channels, H, W)
    when ``is_pool=False``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv_relu = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        pooled = self.pool(x) if is_pool else x
        return self.conv_relu(pooled)


class upSample(nn.Module):
    """Decoder stage: two 3x3 conv + ReLU layers, then a 2x transposed conv + ReLU.

    Takes ``2 * channels`` input channels (in ``Unet`` this is a skip connection
    concatenated with upsampled features). The convs reduce them to
    ``channels``. The transposed conv (kernel 3, stride 2, padding 1,
    output_padding 1) then doubles H and W and halves the channels to
    ``channels // 2``.
    """

    def __init__(self, channels):
        super().__init__()
        half = channels // 2
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv_relu = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.upConv_relu = nn.Sequential(
            nn.ConvTranspose2d(channels, half, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        features = self.conv_relu(x)
        return self.upConv_relu(features)


class Unet(nn.Module):
    """U-Net producing per-pixel class logits.

    Args:
        in_channels: channels of the input image (default 3, matching the
            RGB-converted CT slices from the dataset).
        num_classes: number of output logit channels (default 2:
            background / liver).

    Input: (N, in_channels, H, W), with H and W divisible by 16. Four 2x
    poolings and four 2x upsamplings must return to the same size for the
    skip concatenations to line up.

    Output: raw logits of shape (N, num_classes, H, W). No softmax is applied;
    the training script passes them to ``nn.CrossEntropyLoss``.

    With the default arguments the parameter layout (state_dict keys and
    shapes) is the same as before, so existing checkpoints still load.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super(Unet, self).__init__()
        # Encoder: channels double at each stage; down2..down5 also halve H, W.
        self.down1 = downSample(in_channels, 64)
        self.down2 = downSample(64, 128)
        self.down3 = downSample(128, 256)
        self.down4 = downSample(256, 512)
        self.down5 = downSample(512, 1024)
        # Bottleneck upsampling: 1024 -> 512 channels, 2x spatial.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024, 512,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.ReLU(inplace=True)
        )
        # Decoder stages (each consumes skip + upsampled features).
        self.up1 = upSample(512)
        self.up2 = upSample(256)
        self.up3 = upSample(128)
        self.conv_2 = downSample(128, 64)  # final two 3x3 convs, used without pooling
        self.last = nn.Conv2d(64, num_classes, kernel_size=1)  # per-pixel class logits

    def forward(self, x):
        # Comments give (channels, height, width) for an H x W input.
        x1 = self.down1(x, is_pool=False)  # 64, H, W
        x2 = self.down2(x1)  # 128, H/2, W/2
        x3 = self.down3(x2)  # 256, H/4, W/4
        x4 = self.down4(x3)  # 512, H/8, W/8
        x5 = self.down5(x4)  # 1024, H/16, W/16

        x6 = self.up(x5)  # 512, H/8, W/8

        x6 = torch.cat([x4, x6], dim=1)  # 1024, H/8, W/8
        x7 = self.up1(x6)  # 256, H/4, W/4
        x7 = torch.cat([x3, x7], dim=1)  # 512, H/4, W/4
        x8 = self.up2(x7)  # 128, H/2, W/2
        x8 = torch.cat([x2, x8], dim=1)  # 256, H/2, W/2
        x9 = self.up3(x8)  # 64, H, W
        x9 = torch.cat([x1, x9], dim=1)  # 128, H, W

        x10 = self.conv_2(x9, is_pool=False)  # 64, H, W

        result = self.last(x10)  # num_classes, H, W
        return result

4. 训练模型

import model as md
import torch
from torch import nn
from torch.optim import lr_scheduler
from dataloader import trainloader as train_dl, testloader as test_dl


# Network, optimizer, learning-rate schedule and loss used by fit().
model = md.Unet()
if torch.cuda.is_available():
    model = model.to('cuda')

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# StepLR multiplies the learning rate by gamma=0.1 every 7 calls to step().
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# CrossEntropyLoss takes raw logits (N, C, H, W) and integer targets (N, H, W).
loss_fn = nn.CrossEntropyLoss()


def _safe_div(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float('nan')


def fit(epoch, model, trainloader, testloader, *, opt=None, scheduler=None, criterion=None):
    """Train ``model`` for one epoch, step the LR scheduler, then evaluate it.

    Args:
        epoch: epoch index, used only in the progress printout.
        model: network returning per-pixel logits of shape (N, C, H, W).
        trainloader, testloader: sized iterables of ``(x, y)`` batches, where
            ``y`` holds integer class indices of shape (N, H, W).
        opt, scheduler, criterion: optional overrides. When omitted, the
            module-level ``optimizer``, ``exp_lr_scheduler`` and ``loss_fn``
            are used, as before.

    Returns:
        ``(train_loss, train_acc, test_loss, test_acc)``. Losses are means over
        batches; accuracies are correct pixels divided by all label pixels. A
        loader that yields no batches gives NaN for its loss and accuracy.
    """
    # Conditional expressions only evaluate the module globals when needed.
    opt = optimizer if opt is None else opt
    scheduler = exp_lr_scheduler if scheduler is None else scheduler
    criterion = loss_fn if criterion is None else criterion

    # Put batches on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0
    running_loss = 0.0
    n_batches = 0

    model.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        opt.zero_grad()
        loss = criterion(y_pred, y)
        loss.backward()
        opt.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.numel()  # label pixels, whatever the image size
            running_loss += loss.item()
            n_batches += 1
    scheduler.step()
    epoch_loss = _safe_div(running_loss, n_batches)
    epoch_acc = _safe_div(correct, total)

    test_correct = 0
    test_total = 0
    test_running_loss = 0.0
    test_batches = 0
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.numel()
            test_running_loss += loss.item()
            test_batches += 1

    epoch_test_loss = _safe_div(test_running_loss, test_batches)
    epoch_test_acc = _safe_div(test_correct, test_total)

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc


# Train for a fixed number of epochs, recording the four per-epoch metrics from fit().
epochs = 10

train_loss, train_acc, test_loss, test_acc = [], [], [], []
history = (train_loss, train_acc, test_loss, test_acc)

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(
        epoch, model, train_dl, test_dl)
    metrics = (epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)
    for series, value in zip(history, metrics):
        series.append(value)

# Persist only the learned parameters (state_dict), not the module object.
PATH = 'unet_model.pth'
torch.save(model.state_dict(), PATH)

5. 测试模型

from model import Unet
import torch
from dataloader import testloader
from matplotlib import pyplot as plt
import numpy as np


PATH = 'unet_model.pth'
my_model = Unet()
# map_location='cpu': the checkpoint may have been saved from a CUDA model,
# while this script runs inference on the CPU.
my_model.load_state_dict(torch.load(PATH, map_location='cpu'))
my_model.eval()

x, y = next(iter(testloader))
# torch.no_grad() only disables gradient tracking inside a `with` block;
# calling it as a bare statement has no effect.
with torch.no_grad():
    pred_y = my_model(x)
pred_y = torch.argmax(pred_y, dim=1)
print(np.unique(y))
print(np.unique(pred_y))
print(pred_y.shape)

# One row per sample: input slice, ground-truth mask, predicted mask.
plt.figure(figsize=(10, 10))
num = min(2, x.shape[0])  # the batch may hold fewer than 2 samples
for i in range(num):
    plt.subplot(num, 3, i*3+1)
    plt.imshow(x[i].permute(1, 2, 0).cpu().numpy())
    plt.subplot(num, 3, i*3+2)
    plt.imshow(y[i].cpu().numpy())
    plt.subplot(num, 3, i*3+3)
    plt.imshow(pred_y[i].cpu().numpy())
plt.show(block=True)

结果

 

  • 4
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值