1. 数据预处理
# preprocessing.py
import os
import numpy as np
import cv2
import pydicom
from matplotlib import pyplot as plt
def windowing(img, window_width, window_center):
    """Apply an intensity window to ``img`` and rescale the result to 8 bits.

    Values at or below ``window_center - window_width / 2`` map to 0, values
    at or above ``window_center + window_width / 2`` map to 255, and values in
    between are scaled linearly (the fractional part is truncated by the
    final ``uint8`` cast).

    Args:
        img: Array-like of intensities, any shape.
        window_width: Width of the window; must be greater than zero.
        window_center: Center of the window.

    Returns:
        A ``uint8`` NumPy array with the same shape as ``img``.

    Raises:
        ValueError: If ``window_width`` is not a positive number.
    """
    width = float(window_width)
    # "not width > 0" also rejects NaN, which would otherwise poison the output.
    if not width > 0:
        raise ValueError(f"window_width must be positive, got {window_width!r}")
    lower_bound = float(window_center) - 0.5 * width
    # asarray lets plain lists/tuples through without copying real arrays.
    scaled = (np.asarray(img) - lower_bound) / width
    return (np.clip(scaled, 0, 1) * 255).astype('uint8')
# Histogram equalization (CLAHE)
def clahe_equalization(img):
    """Equalize every 2-D slice of a 3-D stack with CLAHE.

    Each slice along the first axis is cast to ``uint8`` and passed through an
    OpenCV CLAHE operator (clip limit 2.0, 8x8 tile grid).

    Args:
        img: 3-D array indexed as ``img[slice, row, col]``.

    Returns:
        A float array with the same shape as ``img``, holding the equalized
        slices divided by 255.
    """
    assert len(img.shape) == 3
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = np.empty(img.shape)
    for slice_idx, plane in enumerate(img):
        # astype always copies, so the caller's data is never handed to OpenCV.
        equalized[slice_idx] = equalizer.apply(plane.astype(np.uint8))
    return equalized / 255
# Raw DICOM pixel values may depend on the scanner; HU values do not.
def get_pixels_hu(scans):
    """Convert a sequence of DICOM slices into an ``int16`` HU volume.

    Every slice is rescaled with its own ``RescaleSlope`` and
    ``RescaleIntercept`` (``hu = slope * pixel + intercept``), so a series
    whose slices carry different rescale values is still converted correctly.
    The scaled values are truncated when cast back to ``int16``.

    Args:
        scans: Non-empty sequence of objects exposing ``pixel_array``,
            ``RescaleSlope`` and ``RescaleIntercept`` (e.g. pydicom datasets),
            all with the same pixel-array shape.

    Returns:
        ``int16`` array of shape ``(len(scans), rows, cols)``.
    """
    image = np.stack([s.pixel_array for s in scans]).astype(np.int16)
    # Raw DICOM data fills the invalid region with -2000; reset it to 0
    # before rescaling.
    image[image == -2000] = 0
    for idx, scan in enumerate(scans):
        slope = scan.RescaleSlope
        intercept = scan.RescaleIntercept
        if slope != 1:
            # Scale in float64 to avoid int16 overflow, then go back to int16.
            image[idx] = (slope * image[idx].astype(np.float64)).astype(np.int16)
        image[idx] += np.int16(intercept)
    return image
# Output directories for the exported PNG files:
# patient_save_path receives the enhanced CT slices,
# liver_save_path receives the matching liver mask slices.
patient_save_path = './dataset/3Dircadb/patient'
liver_save_path = './dataset/3Dircadb/mask'
def _read_sorted_slices(directory):
    """Read every DICOM file in ``directory``, ordered by ``InstanceNumber``."""
    slices = [pydicom.dcmread(os.path.join(directory, file_name))
              for file_name in os.listdir(directory)]
    slices.sort(key=lambda s: s.InstanceNumber)
    return slices


# Read and process the data
def processimage(start, end):
    """Export liver-containing CT slices and their liver masks as PNG files.

    For every patient number in ``range(start, end)`` (``end`` is excluded):

    1. the CT series in ``3Dircadb1.<num>/PATIENT_DICOM`` is read, converted
       to HU, windowed (width 500, center 150) and CLAHE-equalized;
    2. the ``liver`` mask series in ``3Dircadb1.<num>/MASKS_DICOM`` is read;
    3. every slice whose mask has a positive pixel sum is written as
       ``<num>_<i>.png`` to ``patient_save_path`` (CT) and
       ``liver_save_path`` (mask).

    The output directories are created if they do not exist.

    Args:
        start: First patient number to process.
        end: Patient number to stop before.

    Returns:
        ``(img_patient, img_liver)`` for the last patient that had a
        ``liver`` mask directory, or ``(None, None)`` if there was none.
    """
    # plt.imsave does not create missing directories.
    os.makedirs(patient_save_path, exist_ok=True)
    os.makedirs(liver_save_path, exist_ok=True)

    img_patient, img_liver = None, None
    for num in range(start, end):
        print("正在处理第%d号病人" % num)
        data_path = './dataset/3Dircadb/3Dircadb1.%d/PATIENT_DICOM' % num
        # CT volume: HU conversion -> windowing -> histogram equalization
        image_array = get_pixels_hu(_read_sorted_slices(data_path))
        img_ct = windowing(image_array, 500, 150)
        img_clahe = clahe_equalization(img_ct)

        # Only the "liver" sub-directory of the mask folder is used.
        mask_path = './dataset/3Dircadb/3Dircadb1.%d/MASKS_DICOM' % num
        liver_paths = sorted(os.path.join(mask_path, name)
                             for name in os.listdir(mask_path) if name == "liver")
        for liver_path in liver_paths:
            liver_slices = _read_sorted_slices(liver_path)
            liver_array = np.array([s.pixel_array for s in liver_slices])
            # Slices without any liver have an all-zero mask; drop them.
            has_liver = np.array([s.sum() > 0 for s in liver_array], dtype=bool)
            img_liver = liver_array[has_liver]
            img_patient = img_clahe[has_liver]
            for i, (ct_slice, mask_slice) in enumerate(zip(img_patient, img_liver)):
                file_name = f'{num}_{i}.png'
                plt.imsave(os.path.join(patient_save_path, file_name), ct_slice, cmap='gray')  # CT slice
                plt.imsave(os.path.join(liver_save_path, file_name), mask_slice, cmap='gray')  # liver mask
    return img_patient, img_liver
# Export the two patient ranges. processimage iterates range(start, end), so
# the end value is excluded: these calls cover patients 1-14 and 16-19.
# NOTE(review): patient numbers 15 and 20 are not covered by either call;
# confirm whether (1, 16) and (16, 21) were intended.
train_image_patient, train_image_tumor = processimage(1, 15)
test_image_patient, test_image_tumor = processimage(16, 20)
2. 构建数据集
# dataloader.py
import torch
from PIL import Image
from torchvision import transforms
from torch.utils import data
import glob
import random
def split_train_test(items, train_size):
    """Split ``items`` into a training head and a test tail.

    Args:
        items: Sequence to split.
        train_size: Number of leading items that go to the training part.

    Returns:
        ``(items[:train_size], items[train_size:])``; together the two parts
        contain every item exactly once.
    """
    return items[:train_size], items[train_size:]


# glob returns paths in arbitrary order, so both listings are sorted to make
# patient[i] and mask[i] refer to the same file name (assumes the two
# directories contain identically named files).
patient = sorted(glob.glob('./dataset/3Dircadb/patient/*.png'))
mask = sorted(glob.glob('./dataset/3Dircadb/mask/*.png'))
if len(patient) != len(mask):
    raise ValueError(
        f"found {len(patient)} patient images but {len(mask)} masks"
    )

# Shuffle images and masks with one shared permutation
index = list(range(len(patient)))
random.shuffle(index)
image = [patient[i] for i in index]
label = [mask[i] for i in index]

# Number of shuffled samples used for training; the rest is the test set.
TRAIN_SIZE = 500
train_images, test_images = split_train_test(image, TRAIN_SIZE)
train_labels, test_labels = split_train_test(label, TRAIN_SIZE)
# Conversion pipeline (a single ToTensor step) applied to every image and
# annotation that the dataset class loads.
transform = transforms.Compose([
transforms.ToTensor()
])
class Portrait_dataset(data.Dataset):
    """Dataset of (image, annotation) pairs loaded from image files.

    Item ``i`` is built from ``img_paths[i]`` and ``anno_paths[i]``: the image
    is converted to 3-channel RGB and the annotation to single-channel
    grayscale, and both go through the module-level ``transform``.

    Args:
        img_paths: Sequence of image file paths.
        anno_paths: Sequence of annotation file paths, same length and order.

    Raises:
        ValueError: If the two sequences differ in length.
    """

    def __init__(self, img_paths, anno_paths):
        if len(img_paths) != len(anno_paths):
            raise ValueError(
                f"got {len(img_paths)} images but {len(anno_paths)} annotations"
            )
        self.imgs = img_paths
        self.annos = anno_paths

    def __getitem__(self, index):
        """Return ``(img_tensor, anno_tensor)`` for the pair at ``index``.

        ``anno_tensor`` has the channel axis removed and is cast to
        ``torch.long`` for use as a class-index target.
        """
        # The context managers close the files once the pixels are converted.
        with Image.open(self.imgs[index]) as pil_img:
            # Convert to a 3-channel color image
            img_tensor = transform(pil_img.convert("RGB"))
        with Image.open(self.annos[index]) as pil_anno:
            # Convert to a single-channel grayscale image
            anno_tensor = transform(pil_anno.convert("L"))
        # Remove only the channel axis so a size-1 height/width is preserved.
        # NOTE(review): the cast truncates, so presumably only full-white mask
        # pixels (scaled to 1.0 by ToTensor) become class 1 - confirm.
        anno_tensor = anno_tensor.squeeze(0).type(torch.long)
        return img_tensor, anno_tensor

    def __len__(self):
        """Return the number of image/annotation pairs."""
        return len(self.imgs)
# Mini-batch size shared by the training and test loaders.
BATCH_SIZE = 2

# Wrap the split path lists in datasets.
train_set = Portrait_dataset(img_paths=train_images, anno_paths=train_labels)
test_set = Portrait_dataset(img_paths=test_images, anno_paths=test_labels)

# Both loaders iterate their dataset in its stored order.
trainloader = data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=False)
testloader = data.DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)
3. 构建Unet model
# model.py
from torch import nn
import torch
class downSample(nn.Module):
    """Encoder stage: optional 2x2 max-pooling, then two 3x3 conv + ReLU layers.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        """Return the convolved features; pool ``x`` first unless ``is_pool`` is false."""
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)
class upSample(nn.Module):
    """Decoder stage: two 3x3 conv + ReLU layers, then a stride-2 transposed conv.

    The input is expected to carry ``2 * channels`` channels; the convolutions
    reduce it to ``channels`` and the transposed convolution outputs
    ``channels // 2`` channels at twice the spatial size.

    Args:
        channels: Channel count after the two convolutions.
    """

    def __init__(self, channels):
        super().__init__()
        doubled, halved = 2 * channels, channels // 2
        self.conv_relu = nn.Sequential(
            nn.Conv2d(doubled, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.upConv_relu = nn.Sequential(
            nn.ConvTranspose2d(
                channels,
                halved,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Convolve ``x`` and upsample the result by a factor of two."""
        return self.upConv_relu(self.conv_relu(x))
class Unet(nn.Module):
    """U-Net with four pooling stages and skip connections.

    The encoder halves the spatial size four times (64 -> 1024 channels) and
    the decoder doubles it four times, concatenating the matching encoder
    output at each level. The input height and width should be divisible by
    16 so that the skip connections line up.

    Args:
        in_channels: Channels of the input image (default 3).
        num_classes: Channels of the output logits (default 2).
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        self.down1 = downSample(in_channels, 64)
        self.down2 = downSample(64, 128)
        self.down3 = downSample(128, 256)
        self.down4 = downSample(256, 512)
        self.down5 = downSample(512, 1024)
        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024, 512,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.ReLU(inplace=True)
        )
        self.up1 = upSample(512)
        self.up2 = upSample(256)
        self.up3 = upSample(128)
        self.conv_2 = downSample(128, 64)  # final two convolutions (used without pooling)
        self.last = nn.Conv2d(64, num_classes, kernel_size=1)  # per-pixel class logits

    def forward(self, x):
        """Return logits of shape ``(N, num_classes, H, W)`` for input ``(N, in_channels, H, W)``."""
        # Shapes below are (channels, height, width) for an H x W input.
        x1 = self.down1(x, is_pool=False)     # (64, H, W)
        x2 = self.down2(x1)                   # (128, H/2, W/2)
        x3 = self.down3(x2)                   # (256, H/4, W/4)
        x4 = self.down4(x3)                   # (512, H/8, W/8)
        x5 = self.down5(x4)                   # (1024, H/16, W/16)
        x6 = self.up(x5)                      # (512, H/8, W/8)
        x6 = torch.cat([x4, x6], dim=1)       # (1024, H/8, W/8)
        x7 = self.up1(x6)                     # (256, H/4, W/4)
        x7 = torch.cat([x3, x7], dim=1)       # (512, H/4, W/4)
        x8 = self.up2(x7)                     # (128, H/2, W/2)
        x8 = torch.cat([x2, x8], dim=1)       # (256, H/2, W/2)
        x9 = self.up3(x8)                     # (64, H, W)
        x9 = torch.cat([x1, x9], dim=1)       # (128, H, W)
        x10 = self.conv_2(x9, is_pool=False)  # (64, H, W)
        result = self.last(x10)               # (num_classes, H, W)
        return result
4. 训练模型
import model as md
import torch
from torch import nn
from torch.optim import lr_scheduler
from dataloader import trainloader as train_dl, testloader as test_dl
# Build the network and move it to the GPU when one is available.
model = md.Unet()
if torch.cuda.is_available():
    model = model.to('cuda')

# Adam optimizer whose learning rate is multiplied by 0.1 every 7 scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)

# Cross-entropy loss over the class logits.
loss_fn = nn.CrossEntropyLoss()
def _run_epoch(model, loader, device, training):
    """Run one pass over ``loader``; return ``(mean_batch_loss, pixel_accuracy)``.

    When ``training`` is true every batch also runs a backward pass and an
    optimizer step (module-level ``optimizer``); otherwise the pass runs
    without gradient tracking. Both values are ``nan`` when ``loader`` yields
    no batches.
    """
    correct = 0
    pixels = 0
    running_loss = 0.0
    batches = 0
    for x, y in loader:
        if device is not None:
            x, y = x.to(device), y.to(device)
        with torch.set_grad_enabled(training):
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Metrics never need gradients.
        correct += (torch.argmax(y_pred.detach(), dim=1) == y).sum().item()
        # Count real target elements instead of assuming 512x512 images.
        pixels += y.numel()
        running_loss += loss.item()
        batches += 1
    if batches == 0 or pixels == 0:
        return float('nan'), float('nan')
    return running_loss / batches, correct / pixels


def fit(epoch, model, trainloader, testloader):
    """Train ``model`` for one epoch, evaluate it, and print the metrics.

    Uses the module-level ``optimizer``, ``loss_fn`` and ``exp_lr_scheduler``;
    the scheduler is stepped once per call, after the training pass. Batches
    are moved to the device that holds the model's parameters.

    Args:
        epoch: Epoch number, used only in the printed summary.
        model: Network to train and evaluate.
        trainloader: Iterable of ``(x, y)`` training batches.
        testloader: Iterable of ``(x, y)`` evaluation batches.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)`` where the
        losses are means over batches and the accuracies are fractions of
        correctly classified target elements.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    model.train()
    epoch_loss, epoch_acc = _run_epoch(model, trainloader, device, training=True)
    exp_lr_scheduler.step()

    model.eval()
    epoch_test_loss, epoch_test_acc = _run_epoch(model, testloader, device, training=False)

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
# Number of passes over the training data.
epochs = 10

# Per-epoch metric histories.
train_loss, train_acc = [], []
test_loss, test_acc = [], []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

# Save the trained weights.
PATH = 'unet_model.pth'
torch.save(model.state_dict(), PATH)
5. 测试模型
from model import Unet
import torch
from dataloader import testloader
from matplotlib import pyplot as plt
import numpy as np
# Load the trained weights into a fresh CPU model.
PATH = 'unet_model.pth'
my_model = Unet()
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
my_model.load_state_dict(torch.load(PATH, map_location='cpu'))
my_model.eval()

# Predict the first test batch. no_grad only has an effect as a context
# manager (or decorator); a bare call leaves gradient tracking enabled.
x, y = next(iter(testloader))
with torch.no_grad():
    pred_y = my_model(x)
pred_y = torch.argmax(pred_y, dim=1)
print(np.unique(y.cpu().numpy()))
print(np.unique(pred_y.cpu().numpy()))
print(pred_y.shape)

# One row per sample: input image, ground-truth mask, predicted mask.
plt.figure(figsize=(10, 10))
num = min(2, x.shape[0])  # the batch may hold fewer than two samples
for i in range(num):
    plt.subplot(num, 3, i * 3 + 1)
    plt.imshow(x[i].permute(1, 2, 0).cpu().numpy())
    plt.subplot(num, 3, i * 3 + 2)
    plt.imshow(y[i].cpu().numpy())
    plt.subplot(num, 3, i * 3 + 3)
    plt.imshow(pred_y[i].detach().numpy())
plt.show(block=True)
结果