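In the previous post, Medical Image Segmentation: Deep-Learning-Based Liver Tumor Segmentation in Practice (Part 1), I implemented liver segmentation. But when I later used the same preprocessing and the same model for tumor segmentation, I ran into problems twice.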
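The first time, the network's Dice coefficient only reached about 40% on the training set and about 10% on the test set, and the predictions were obviously wrong. I suspected a problem with the data, so I went back and looked at many more raw images, and found that many tumors could not be distinguished by eye. I then ran an experiment to test this hypothesis, written up in the post Medical Image Preprocessing (5): Histograms of Organs and Lesions. It confirmed that for many patients the gray values (more precisely, the HU values) of the liver and the tumor almost completely overlap.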
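For reference, here is a minimal sketch of how that overlap could be quantified per patient with histogram intersection. It is not the code from the histogram post: the function name `hu_histogram_overlap`, the HU bin range of [-100, 300] with 5 HU bins, and the LiTS label convention (0 = background, 1 = liver, 2 = tumor) are assumptions for illustration.

```python
import numpy as np


def hu_histogram_overlap(volume, seg, low=-100, high=300, bin_width=5):
    """Histogram-intersection overlap (0..1) between liver and tumor HU values.

    Assumes LiTS-style labels: 0 = background, 1 = liver, 2 = tumor.
    A value near 1 means the two tissues are nearly indistinguishable by intensity alone.
    """
    bins = np.arange(low, high + bin_width, bin_width)
    h_liver, _ = np.histogram(volume[seg == 1], bins=bins)
    h_tumor, _ = np.histogram(volume[seg == 2], bins=bins)
    if h_liver.sum() == 0 or h_tumor.sum() == 0:
        return float("nan")  # one tissue has no voxels inside the bin range
    p_liver = h_liver / h_liver.sum()
    p_tumor = h_tumor / h_tumor.sum()
    # Intersection of two normalized histograms: sum of the bin-wise minimum
    return float(np.minimum(p_liver, p_tumor).sum())


# Toy example: 1-D arrays stand in for a real CT volume and its label map
seg = np.array([1] * 10 + [2] * 10)
separated = np.array([100.0] * 10 + [40.0] * 10)  # tumor darker than liver
overlapping = np.array([100.0] * 20)              # identical intensities
print(hu_histogram_overlap(separated, seg))    # 0.0
print(hu_histogram_overlap(overlapping, seg))  # 1.0
```

Sorting patients by a score like this could take over part of the manual grading, though the cut-offs between levels would still have to be chosen by inspection.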
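With the hypothesis confirmed, I started the experiments over. I switched from the 3Dircadb database to the larger LiTS2017 database. Then, based on each patient's liver and tumor histograms, I manually divided the 130 patients into three levels. level 1: highest liver/tumor contrast; level 2: medium contrast; level 3: almost no contrast. Only level 1 and level 2 were used for training.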
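Of course, the data preprocessing and the ROI step were the same as before.
But this round of experiments went wrong as well. That was the second problem.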
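Second problem: the Dice coefficient on the training set would climb to about 10% and then suddenly drop back to nearly 0. I couldn't figure it out. I assumed it was a vanishing or exploding gradient problem and read a lot of material on how to fix that.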
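But after thinking it over, I noticed a strange and serious problem. The ROI step ANDs the ground-truth liver mask with the original image, which turns everything outside the region of interest black so that the network can concentrate on the inside of the liver. But tumors inside the liver are also dark (their gray values are usually lower than the liver's)! Why should black outside the liver mean "not of interest" while dark regions inside the liver are the "target"? Thinking back to the images from the liver segmentation task, I realized that windowing and the other steps there also turned the non-target regions black and made the liver stand out (the liver became light gray, and some other organs were even whiter).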
So I tried a bold experiment: make the liver gray, the tumor white, and the region outside the liver black (that is, invert the intensities).
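To see why the inversion helps, here is a toy sketch with three pixels (background, liver, tumor) and made-up HU values of -1000, 100 and 40. It reuses the [-50, 200] window from the code below; the helper name `window` is hypothetical, and the real pipeline also applies CLAHE between windowing and inversion, which this sketch skips.

```python
import numpy as np


def window(hu, width=250.0, center=75.0):
    """Clip HU values to [center - width/2, center + width/2] and scale to [0, 1]."""
    low = center - 0.5 * width
    return np.clip((hu - low) / width, 0.0, 1.0)


hu = np.array([-1000.0, 100.0, 40.0])   # background, liver, tumor (toy values)
liver_mask = np.array([0.0, 1.0, 1.0])  # tumor pixels lie inside the liver mask

w = window(hu)                          # about [0.0, 0.6, 0.36]
roi_plain = w * liver_mask              # old ROI: tumor darker than liver, closer to the black background
roi_inverted = (1.0 - w) * liver_mask   # invert, THEN mask: background 0, liver 0.4, tumor 0.64
wrong_order = 1.0 - w * liver_mask      # mask, then invert: the background turns white (1.0)

print(roi_plain, roi_inverted, wrong_order)
```

This is also why, in the preprocessing code, the inversion `srcimg = 1 - srcimg` has to come before the multiplication by the liver mask.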
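This time the Dice coefficient reached about 90% on the training set and about 70% on the test set. Not great, but it at least shows that the idea works. It also shows that data is king: you reap what you sow.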
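The Dice coefficient is 2|A∩B| / (|A| + |B|). For checking saved predictions offline, a NumPy version of the same formula, with the same smoothing term as the Keras `dice_coef` in Part 2, might look like the sketch below. The name `dice_score` and the 0.5 threshold are assumptions; the Keras metric works on the soft sigmoid outputs batch by batch, so the two numbers will not match exactly.

```python
import numpy as np


def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
    """Dice = (2 * |A and B| + smooth) / (|A| + |B| + smooth) on binarized masks."""
    t = (np.asarray(y_true, dtype=np.float64) > 0.5).ravel()
    p = (np.asarray(y_pred, dtype=np.float64) > threshold).ravel()
    intersection = np.logical_and(t, p).sum()
    return (2.0 * intersection + smooth) / (t.sum() + p.sum() + smooth)


y_true = np.array([[1, 1], [0, 0]])
y_pred = np.array([[0.9, 0.2], [0.1, 0.8]])  # sigmoid outputs
print(dice_score(y_true, y_pred))  # (2*1 + 1) / (2 + 2 + 1) = 0.6
```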
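Below is the code for the experiment. Part 1 prepares the dataset (preprocessing, then writing the results to h5 files) and runs locally. Part 2 builds and trains the model and runs on a server (Ubuntu 16.04, tensorflow-gpu).
Part 1:
(Note: the utility classes for reading and writing h5 files are also on my blog.)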
# -*- coding: utf-8 -*-
"""
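Based on the results observed with LITS_check.py,
the patients are split into 3 levels by liver/tumor contrast:
level 1: highest contrast (two patients randomly chosen as the validation set)
level 2: medium contrast (two patients randomly chosen as the validation set)
level 3: lowest contrast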
"""
# threshold = 1e-3, total = 755 selected slices
# patients 81 and 125 are used as the test set
level_1 = [0,1,22,23,25,26,27,31,37,46,49,50,55,57,58,59,61,62,
           63,64,66,78,79,82,83,90,92,95,99,109,112,124]
#level_1 = [63,64,66,78,79,81,82,83]
# threshold = 1e-3, total = 1345 selected slices
# patients 11 and 110 are used as the test set
level_2 = [2,7,8,9,10,12,14,15,17,28,35,40,42,
           53,56,69,76,93,96,101,111,113,117]
level12 = level_1 + level_2
level12.sort()
test_list = [11,81,110,125]
a = [i for i in range(130)]
level_3 =list(set(a)-set(level_1)-set(level_2))
# sort() modifies the list in place and returns None
level_3.sort()
"""
Inspected the remaining level_1 images to confirm that their contrast is high.
After inspection, the window was set to [-50, 200].
"""
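import os
onServer = False
if onServer:
    niiSegPath = './LITS17/seg/'
    niiImagePath = './LITS17/ct/'
else:
    # "~" is expanded by the shell, not by SimpleITK, so expand it explicitly
    niiSegPath = os.path.expanduser('~/Documents/LITS17/seg/')
    niiImagePath = os.path.expanduser('~/Documents/LITS17/ct/')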
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
def getRangeImageDepth(image):
    z = np.any(image, axis=(1, 2))  # z.shape: (depth,)
    # print("all index:", np.where(z)[0])
    if len(np.where(z)[0]) > 0:
        startposition, endposition = np.where(z)[0][[0, -1]]
    else:
        startposition = endposition = 0
    return startposition, endposition
def sample_stack(stack, name="images.png", rows=4, cols=2, start_with=0, show_every=1):
    # squeeze=False always returns a 2-D array of axes, so one indexing scheme
    # covers every grid shape (including 1 x 1, which the old version could not handle)
    fig, ax = plt.subplots(rows, cols, figsize=[5*cols, 5*rows], squeeze=False)
    for i in range(rows*cols):
        ind = start_with + i*show_every
        a = ax[i // cols, i % cols]
        a.set_title('slice %d' % ind)
        a.imshow(stack[ind], cmap='gray')
        a.axis('off')
    # savefig must be called before show(); after show() it would save a new blank figure
    # plt.savefig(name)
    plt.show()
"""
Utility function: original image on the left, ground-truth segmentation on the right
"""
def show_src_seg(srcimg, segimg, index, rows=3, start_with=0, show_every=1):
    assert srcimg.shape == segimg.shape
    # Plot every show_every-th slice from start_with to the last slice
    # (as before, the rows argument is overridden by the number of available slices;
    # the old row arithmetic silently dropped the last slice)
    indices = list(range(start_with, srcimg.shape[0], show_every))
    rows = len(indices)
    cols = 2
    print("final rows=%d" % rows)
    # squeeze=False keeps ax 2-D even when only one slice is plotted
    fig, ax = plt.subplots(rows, cols, figsize=[5*cols, 5*rows], squeeze=False)
    for i, ind in enumerate(indices):
        ax[i, 0].set_title('src slice %d' % ind)
        ax[i, 0].imshow(srcimg[ind], cmap='gray')
        ax[i, 0].axis('off')
        ax[i, 1].set_title('truth seg slice %d' % ind)
        ax[i, 1].imshow(segimg[ind], cmap='gray')
        ax[i, 1].axis('off')
    # savefig must be called before show(); after show() it would save a new blank figure
    name = "../LITS/crop/" + str(index) + ".png"
    plt.savefig(name)
    plt.close(fig)  # release the figure, since this is called once per patient
    # plt.show()
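def transform_ctdata(image, windowWidth, windowCenter, normal=False):
    """
    Clip a CT image to the given window and rescale it to [0, 1]
    (or to uint8 [0, 255] when normal=False).
    The input must be processed as floating point; it is cast to float32 here
    so the scaling is never done in integer arithmetic.
    return: image truncated according to window center and window width
    """
    image = np.asarray(image, dtype=np.float32)
    minWindow = float(windowCenter) - 0.5*float(windowWidth)
    newimg = (image - minWindow) / float(windowWidth)
    newimg[newimg < 0] = 0
    newimg[newimg > 1] = 1
    if not normal:
        newimg = (newimg * 255).astype('uint8')
    return newimg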
import cv2
def clahe_equalized(imgs):
    assert (len(imgs.shape) == 3)  # 3D arrays
    # create a CLAHE object (arguments are optional)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    imgs_equalized = np.empty(imgs.shape)  # float64 output
    for i in range(len(imgs)):
        imgs_equalized[i, :, :] = clahe.apply(np.array(imgs[i, :, :], dtype=np.uint8))
    return imgs_equalized
# Using the ground-truth liver mask, crop each slice to a fixed-size window centered on the liver
def crop_images_func(refer_images, target_images, target_tumors):
    assert refer_images.shape == target_images.shape == target_tumors.shape
    crop_images = []
    crop_tumors = []
    # Crop size: the largest liver bounding box seen in earlier runs
    # (height 266, width 334), enlarged a little
    height = 280
    width = 360
    img_h, img_w = refer_images.shape[1:]  # 512 x 512 for LiTS
    for i in range(refer_images.shape[0]):
        mask = refer_images[i]
        # find coordinates of the liver; the first array holds row indices
        rows, cols = np.nonzero(mask)
        height_center = (rows.min() + rows.max()) // 2
        width_center = (cols.min() + cols.max()) // 2
        # Center the window on the liver, then shift it back inside the image
        # if it sticks out on either side
        istart = int(height_center - height/2)
        istart = min(max(istart, 0), img_h - height)
        iend = istart + height
        jstart = int(width_center - width/2)
        jstart = min(max(jstart, 0), img_w - width)
        jend = jstart + width
        # print("[%d:%d,%d:%d]" % (istart, iend, jstart, jend))
        crop_images.append(target_images[i, istart:iend, jstart:jend])
        crop_tumors.append(target_tumors[i, istart:iend, jstart:jend])
    crop_images = np.asarray(crop_images)
    crop_tumors = np.asarray(crop_tumors)
    return (crop_images, crop_tumors)
"""
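Training data
Step 1: read the data
Step 2: find the slices that contain tumor (any slice with tumor also contains liver)
Step 3: preprocess
    windowing, adaptive histogram equalization (CLAHE), normalization, color inversion, ROI
Step 4: crop
Step 5: write the data to file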
"""
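# The utility class is on my blog
from HDF5DatasetWriter import HDF5DatasetWriter
# 1967 = number of tumor slices selected from level12 by the loop below (the final value of count)
dataset = HDF5DatasetWriter(image_dims=(1967, 280, 360, 1),
                            mask_dims=(1967, 280, 360, 1),
                            outputPath="../data_train/LITS_train_tumor_crop.h5")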
count = 0
for i in level12:
    seg = sitk.ReadImage(niiSegPath + "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath + "volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)
    # liver mask: labels 1 (liver) and 2 (tumor) both count as liver
    seg_liver = segimg.copy()
    seg_liver[seg_liver > 0] = 1
    # tumor mask: label 2 only
    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    # Keep only the ROI. Note: the validation loop below skips this early masking,
    # so CLAHE sees slightly different inputs for the two sets
    srcimg = srcimg * seg_liver
    start, end = getRangeImageDepth(seg_tumorimage)
    if start == 0 and end == 0:
        print("continue")
        continue
    print("start:", start, " end:", end)
    threshold = 1e-3  # minimum fraction of tumor pixels for a slice to be kept
    filter_index = []
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > threshold:
            filter_index.append(j)
    if len(filter_index) < 1:
        continue
    count += len(filter_index)
    # print("picked index:", filter_index)
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
    # window [-50, 200] HU: width 250, center 75
    srcimg = transform_ctdata(srcimg, 250, 75, normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    # The order of the next two steps must not be swapped;
    # masking first and then inverting would turn the background white
    srcimg = 1 - srcimg
    # keep only the ROI
    srcimg = srcimg * seg_liver
    crop_images, crop_tumors = crop_images_func(seg_liver, srcimg, seg_tumorimage)
    # show_src_seg(crop_images, crop_tumors, index=i)
    crop_images = np.expand_dims(crop_images, axis=-1)
    crop_tumors = np.expand_dims(crop_tumors, axis=-1)
    dataset.add(crop_images, crop_tumors)
print(dataset.close())
# 133 = number of tumor slices selected from test_list by the loop below
dataset = HDF5DatasetWriter(image_dims=(133, 280, 360, 1),
                            mask_dims=(133, 280, 360, 1),
                            outputPath="../data_train/LITS_val_tumor_crop.h5")
count = 0
for i in test_list:
    seg = sitk.ReadImage(niiSegPath + "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath + "volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)
    seg_liver = segimg.copy()
    seg_liver[seg_liver > 0] = 1
    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    start, end = getRangeImageDepth(seg_tumorimage)
    if start == 0 and end == 0:
        print("continue")
        continue
    print("start:", start, " end:", end)
    threshold = 1e-3  # minimum fraction of tumor pixels for a slice to be kept
    filter_index = []
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > threshold:
            filter_index.append(j)
    if len(filter_index) < 1:
        continue
    count += len(filter_index)
    # print("picked index:", filter_index)
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
    srcimg = transform_ctdata(srcimg, 250, 75, normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    # invert first, then mask (same order as for the training set)
    srcimg = 1 - srcimg
    # keep only the ROI
    srcimg = srcimg * seg_liver
    crop_images, crop_tumors = crop_images_func(seg_liver, srcimg, seg_tumorimage)
    # save a source / ground-truth preview for each validation patient
    show_src_seg(crop_images, crop_tumors, index=i)
    crop_images = np.expand_dims(crop_images, axis=-1)
    crop_tumors = np.expand_dims(crop_tumors, axis=-1)
    dataset.add(crop_images, crop_tumors)
print(dataset.close())
"""
# Test: read one batch back from the h5 file and display it
from HDF5DatasetGenerator import HDF5DatasetGenerator
outputPath = '../data_train/LITS_train_tumor_crop.h5'
val_outputPath = '../data_train/LITS_val_tumor_crop.h5'
BATCH_SIZE = 8
reader = HDF5DatasetGenerator(dbPath=val_outputPath,batchSize=BATCH_SIZE)
train_iter = reader.generator()
src,seg = train_iter.__next__()
src = np.squeeze(src)
seg = np.squeeze(seg)
sample_stack(src)
sample_stack(seg)
"""
Part 2:
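A note on shapes before the listing: the 280 x 360 crops are not divisible by 16, so at the deepest decoder level the upsampled map is one pixel smaller than the skip connection (17 * 2 = 34 vs 35 rows, 22 * 2 = 44 vs 45 columns). That is why the U-Net below pads with `ZeroPadding2D` using the output of `get_crop_shape`. The small pure-Python sketch below replays that arithmetic for one spatial axis; the helper `unet_skip_padding` is hypothetical, and it assumes MaxPooling2D with the default 'valid' padding (floor division) and Conv2DTranspose with stride 2 and padding='same' (exact doubling), as in the model code.

```python
def unet_skip_padding(size, depth=4):
    """Replay encoder/decoder sizes for one spatial axis of a U-Net.

    Returns the encoder sizes and, for each decoder level (deepest first),
    the (before, after) zero-padding that makes the upsampled map match the skip connection.
    """
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)  # 2x2 max pooling, 'valid' padding: floor(n / 2)
    pads = []
    for level in range(depth, 0, -1):
        upsampled = sizes[level] * 2  # stride-2 transposed conv, 'same' padding: exactly 2n
        diff = sizes[level - 1] - upsampled
        pads.append((diff // 2, diff - diff // 2))  # same split as get_crop_shape
    return sizes, pads


print(unet_skip_padding(280))  # ([280, 140, 70, 35, 17], [(0, 1), (0, 0), (0, 0), (0, 0)])
print(unet_skip_padding(360))  # ([360, 180, 90, 45, 22], [(0, 1), (0, 0), (0, 0), (0, 0)])
```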
# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import random
import math
import tensorflow as tf
from HDF5DatasetGenerator import HDF5DatasetGenerator
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,Cropping2D,ZeroPadding2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from skimage import io
from keras import losses
# Set some parameters
IMG_WIDTH = 360
IMG_HEIGHT = 280
IMG_CHANNELS = 1
TOTAL = 1967  # total number of training slices
TOTAL_VAL = 133  # total number of validation slices
outputPath = '../data_train/LITS_train_tumor_crop.h5'  # training file
val_outputPath = '../data_train/LITS_val_tumor_crop.h5'  # validation file
#checkpoint_path = 'model.ckpt'
BATCH_SIZE = 4
K.set_image_data_format('channels_last')
def dice_coef(y_true, y_pred):
    print("in loss function, y_true shape:", y_true.shape)
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)
# Question: should I divide by n here, or does Keras take the average automatically?
def weighted_binary_cross_entropy_loss(y_true, y_pred):
    """
    # close to the standard result (0.068760); this version gives 0.0685122
    print("y_pred shape ", K.int_shape(y_pred))
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    ce = - K.mean(y_true*K.log(K.epsilon()+y_pred) + (1-y_true)*K.log(1-y_pred+K.epsilon()))
    return ce
    """
    """
    # same as the standard result
    b_ce = K.binary_crossentropy(y_true, y_pred)
    return b_ce
    """
    # Not sure whether this is correct
    # Calculate the binary crossentropy
    b_ce = K.binary_crossentropy(y_true, y_pred)
    # Class-balancing weights: tumor pixels get the background fraction,
    # background pixels get the tumor fraction
    one_weight = K.mean(y_true)
    zero_weight = 1 - one_weight
    # weight = zero_weight / one_weight
    # Apply the weights
    weight_vector = y_true * zero_weight + (1. - y_true) * one_weight
    weighted_b_ce = weight_vector * b_ce
    # Return the mean error
    return K.mean(weighted_b_ce)
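# Not sure whether this is correct.
# This has the form of a generalized Dice loss: each class is weighted by 1 / (its share of pixels)^2.
# Using the mean instead of the sum rescales both weights by the same factor, which cancels in the ratio.
def weighted_dice_loss(y_true, y_pred):
    mean = K.mean(y_true)
    # K.epsilon() avoids division by zero for a batch with no tumor (or only tumor)
    w_1 = 1 / (mean**2 + K.epsilon())
    w_0 = 1 / ((1 - mean)**2 + K.epsilon())
    y_true_f_1 = K.flatten(y_true)
    y_pred_f_1 = K.flatten(y_pred)
    y_true_f_0 = K.flatten(1 - y_true)
    y_pred_f_0 = K.flatten(1 - y_pred)
    intersection_0 = K.sum(y_true_f_0 * y_pred_f_0)
    intersection_1 = K.sum(y_true_f_1 * y_pred_f_1)
    numerator = w_0 * intersection_0 + w_1 * intersection_1
    denominator = (w_0 * (K.sum(y_true_f_0) + K.sum(y_pred_f_0))
                   + w_1 * (K.sum(y_true_f_1) + K.sum(y_pred_f_1)))
    return -2 * numerator / denominator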
def get_crop_shape(target, refer):
    # Despite the name, the result is used as ZeroPadding2D padding: how much
    # refer (the upsampled map) must grow to match target (the skip connection).
    # K.int_shape reads the static shape without the private _keras_shape attribute.
    # width, the 3rd dimension
    cw = K.int_shape(target)[2] - K.int_shape(refer)[2]
    assert (cw >= 0)
    cw1, cw2 = cw // 2, cw - cw // 2  # the extra pixel for an odd difference goes after
    # height, the 2nd dimension
    ch = K.int_shape(target)[1] - K.int_shape(refer)[1]
    assert (ch >= 0)
    ch1, ch2 = ch // 2, ch - ch // 2
    return (ch1, ch2), (cw1, cw2)
def get_unet():
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, 1))
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
    up_conv5 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5)
    ch, cw = get_crop_shape(conv4, up_conv5)
    # print("ch,cw", ch, cw)
    up_conv5 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv5)
    up6 = concatenate([up_conv5, conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    up_conv6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6)
    ch, cw = get_crop_shape(conv3, up_conv6)
    up_conv6 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv6)
    up7 = concatenate([up_conv6, conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    up_conv7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    up_conv7 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv7)
    up8 = concatenate([up_conv7, conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    up_conv8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    up_conv8 = ZeroPadding2D(padding=(ch, cw), data_format="channels_last")(up_conv8)
    up9 = concatenate([up_conv8, conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    model = Model(inputs=[inputs], outputs=[conv10])
    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
    return model
class UnetModel:
    def predict(self):
        model = get_unet()
        model.load_weights('weights2.h5')
        # Take one batch of 30 slices from the training file and save
        # input / prediction / ground truth for each slice as PNGs
        test_reader = HDF5DatasetGenerator(dbPath=outputPath, batchSize=30)
        test_iter = test_reader.generator()
        fixed_test_images, fixed_test_masks = next(test_iter)
        # print(model.evaluate(fixed_test_images, fixed_test_masks, BATCH_SIZE*5))
        imgs_mask_test = model.predict(fixed_test_images, verbose=1)
        test_reader.close()
        print('-' * 30)
        print('Saving predicted masks to files...')
        print('-' * 30)
        pred_dir = 'step2_train1'
        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)
        for i, image in enumerate(imgs_mask_test):
            image = (image[:, :, 0] * 255.).astype(np.uint8)
            gt = (fixed_test_masks[i, :, :, 0] * 255.).astype(np.uint8)
            ini = (fixed_test_images[i, :, :, 0] * 255.).astype(np.uint8)
            io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
            io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
            io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)

    def train_and_predict(self):
        reader = HDF5DatasetGenerator(dbPath=outputPath, batchSize=BATCH_SIZE)
        train_iter = reader.generator()
        test_reader = HDF5DatasetGenerator(dbPath=val_outputPath, batchSize=BATCH_SIZE)
        test_iter = test_reader.generator()
        model = get_unet()
        # keep only the weights with the lowest validation loss
        model_checkpoint = ModelCheckpoint('weights2.h5', monitor='val_loss', save_best_only=True)
        # shuffle only applies to keras.utils.Sequence inputs; it has no effect on a plain generator
        model.fit_generator(train_iter, steps_per_epoch=int(TOTAL/BATCH_SIZE), verbose=1, epochs=500, shuffle=True,
                            validation_data=test_iter, validation_steps=int(TOTAL_VAL/BATCH_SIZE),
                            callbacks=[model_checkpoint])
        reader.close()
        test_reader.close()
# print('-'*30)
# print('Loading and preprocessing test data...')
# print('-'*30)
#
# print('-'*30)
# print('Loading saved weights...')
# print('-'*30)
# model.load_weights('weights.h5')
#
# print('-'*30)
# print('Predicting masks on test data...')
# print('-'*30)
#
#
#
# # Not sure why the result is in NumPy format here
# imgs_mask_test = model.predict(fixed_test_images, verbose=1)
# np.save('imgs_mask_test.npy', imgs_mask_test)
#
# print('-' * 30)
# print('Saving predicted masks to files...')
# print('-' * 30)
# pred_dir = 'preds'
# if not os.path.exists(pred_dir):
# os.mkdir(pred_dir)
# i = 0
#
#
# for image in imgs_mask_test:
# image = (image[:, :, 0] * 255.).astype(np.uint8)
# gt = (fixed_test_masks[i,:,:,0] * 255.).astype(np.uint8)
# ini = (fixed_test_images[i,:,:,0] *255.).astype(np.uint8)
# io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
# io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
# io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
# i += 1
#model = get_unet()
#model.summary()
unet = UnetModel()
#unet.train_and_predict()
unet.train_and_predict()
#print("test")
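About the "not sure whether this is correct" note on `weighted_binary_cross_entropy_loss`: one way to check the weighting is to replay it in NumPy and look at the total weight each class receives. The sketch below is a hypothetical re-implementation (the name `class_balanced_bce` and the clipping epsilon are my own choices, not part of the training code). With p the fraction of tumor pixels, every tumor pixel gets weight 1 - p and every background pixel gets weight p, so both classes end up with the same total weight N * p * (1 - p).

```python
import numpy as np


def class_balanced_bce(y_true, y_pred, eps=1e-7):
    """NumPy replay of the class-balanced BCE in weighted_binary_cross_entropy_loss."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)  # avoid log(0)
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    one_weight = y_true.mean()        # fraction of tumor pixels, p
    zero_weight = 1.0 - one_weight    # fraction of background pixels, 1 - p
    weights = y_true * zero_weight + (1.0 - y_true) * one_weight
    return float((weights * bce).mean()), weights


y_true = np.array([1.0] + [0.0] * 9)   # 1 tumor pixel, 9 background pixels
y_pred = np.full(10, 0.5)
loss, weights = class_balanced_bce(y_true, y_pred)
print(weights[y_true == 1].sum(), weights[y_true == 0].sum())  # both about 0.9
```

The weights are recomputed for every batch, so a batch with no tumor pixels would give the background weight 0. With the slice filtering in Part 1, every training slice contains tumor, so this should not happen during training.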