医学图像分割 基于深度学习的肝脏肿瘤分割 实战(二)

医学图像分割 基于深度学习的肝脏肿瘤分割 实战(一)中,实现了对肝脏的分割,但是后续在使用相同的处理方法与模型进行肿瘤分割的时候,遇到了两次问题。

第一次,网络的dice系数,训练集上一直只能达到40%左右,测试集上只有10%左右,而且结果明显不对。我认为是数据层面有问题,随机又看了很多原始图片,发现很多肿瘤肉眼都无法分辨,于是进行了实验验证猜想,实验写在了博客:医学图像预处理(五) 器官与病灶的直方图即验证很多病人的肝脏和肿瘤灰度值(更准确地说是hu值)几乎是重叠的

实验验证了猜想,开始重新做实验。我于是从3Dircadb数据库换成了LiTS2017数据库,后者的数据量更大。然后根据病人的肝脏与肿瘤直方图分布,手动将130位病人划分成了三个等级:level1:肝脏与肿瘤对比度最大。level3:肝脏与肿瘤的对比度几乎没有。训练数据只选择level1&level2。

当然,数据预处理操作和ROI操作和之前一样,是都有的。
但是这回做实验,一样出了问题。即第二次问题。

第二次:先是dice系数在训练集上可以到10%,可是又会突然回到接近0的值。我百思不得其解,以为是发生了梯度消失或者梯度爆炸的问题,还差了很多资料看如何解决。
可是,后来想着想着,发现一个奇怪又严肃的问题,就是ROI操作会将真实肝脏分割结果与原图做“与”操作,这样,非感兴趣的区域就会变黑,就可以让网络集中注意力在肝脏内部的区域,可是,肝脏内部的肿瘤也是黑色(一般灰度值低于肝脏)的呀??为什么外面变成黑色就是“非感兴趣区域”,而肝脏里的就得是“目标”呢? 想了想之前肝脏分割的图片,发现做窗口值等的操作,也是将非目标区域变黑,然后可以突出肝脏(肝脏变成灰白色,当然,还有其他器官更是白色)。
于是,我做了一个大胆的实验,将肝脏变成灰色,肿瘤变成白色,肝脏外的区域为黑色。(进行了颜色翻转
在这里插入图片描述
在这里插入图片描述
这一次实验,训练集上,dice系数90%左右,测试集上70%左右。虽然不够好,但是至少证明了猜想是可行的。也反映了,数据是王道:种瓜得瓜,种豆得豆。

下面是实验的代码,第一部分是准备数据集的过程(预处理后写成h5文件),这部分在本地环境进行。第二部分是模型构建和训练过程,这部分在服务器进行(ubuntu16.04, tensorflow-gpu)

第一部分:
(注:有关h5文件读写的工具类也放在了博客里)

# -*- coding: utf-8 -*-

"""
根据LITS_check.py,观察结果
根据肝脏与肿瘤的对比度,将病人分成 3 level
1 level:对比度最高(随机选出两个作为validation集)
2 level: 对比度中等(随机选出两个作为validation集)
3 level:对比度最低
"""
# theshold = 1e-3, total=755
# 81,125作为测试集
level_1 = [0,1,22,23,25,26,27,31,37,46,49,50,55,57,58,59,61,62,
           63,64,66,78,79,82,83,90,92,95,99,109,112,124]

#level_1 = [63,64,66,78,79,81,82,83]
# theshold = 1e-3, total= 1345
# 11,110作为测试集
level_2 = [2,7,8,9,10,12,14,15,17,28,35,40,42,
           53,56,69,76,93,96,101,111,113,117]

level12 = level_1 + level_2
level12.sort()

test_list = [11,81,110,125]

a = [i for i in range(130)]
level_3 =list(set(a)-set(level_1)-set(level_2))
# sort方法直接改变原列表,无返回值
level_3.sort()


"""
将level_1的其余图片观察,确定是否对比度高
观察后确定窗口值为:[-50,200]
"""
onServer = False
if onServer:
    niiSegPath = './LITS17/seg/'
    niiImagePath = './LITS17/ct/'
else:
    niiSegPath = '~/Documents/LITS17/seg/'
    niiImagePath = '~/Documents/LITS17/ct/'


import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

def getRangeImageDepth(image):
    z = np.any(image, axis=(1,2)) # z.shape:(depth,)
    #print("all index:",np.where(z)[0])
    if len(np.where(z)[0]) >0:
        startposition,endposition = np.where(z)[0][[0,-1]]
    else:
        startposition = endposition = 0
    
    return startposition, endposition

def sample_stack(stack, name="images.png", rows=4, cols=2, start_with=0, show_every=1):
    fig,ax = plt.subplots(rows,cols,figsize=[5*cols,5*rows])
    if rows==1 or cols==1 :
        nums = rows*cols
        for i in range(nums):
            ind = start_with + i*show_every
            ax[int(i % nums)].set_title('slice %d' % ind)
            ax[int(i % nums)].imshow(stack[ind],cmap='gray')
            ax[int(i % nums)].axis('off')
    else:
        for i in range(rows*cols):
            ind = start_with + i*show_every
            ax[int(i/cols),int(i % cols)].set_title('slice %d' % ind)
            ax[int(i/cols),int(i % cols)].imshow(stack[ind],cmap='gray')
            ax[int(i/cols),int(i % cols)].axis('off')
    # 这句话一定要在show之前写,否则show函数之后会创建新的空白图
#    plt.savefig(name)
    plt.show()

"""
工具函数,左边原图,右边真实分割图
"""
def show_src_seg(srcimg, segimg,index, rows=3,start_with=0, show_every=1):
    assert srcimg.shape == segimg.shape
    
    rows = srcimg.shape[0]
    plan_rows = start_with + rows*show_every - 1
    print("rows=%d,planned_rows=%d"%(rows,plan_rows))
    
    rows = plan_rows if (rows > plan_rows) else rows
    cols = 2
    print("final rows=%d"%rows)
    
    fig,ax = plt.subplots(rows,cols,figsize=[5*cols,5*rows])
    for i in range(rows):
        ind = start_with + i*show_every
        ax[i,0].set_title('src slice %d' % ind)
        ax[i,0].imshow(srcimg[ind],cmap='gray')
        ax[i,0].axis('off')
        
        ax[i,1].set_title('truth seg slice %d' % ind)
        ax[i,1].imshow(segimg[ind],cmap='gray')
        ax[i,1].axis('off')
    # 这句话一定要在show之前写,否则show函数之后会创建新的空白图
    name = "../LITS/crop/"+str(index)+".png"
    plt.savefig(name)
#    plt.show()


def transform_ctdata(image, windowWidth, windowCenter, normal=False):
        """
        注意,这个函数的self.image一定得是float类型的,否则就无效!
        return: trucated image according to window center and window width
        """
        minWindow = float(windowCenter) - 0.5*float(windowWidth)
        newimg = (image - minWindow) / float(windowWidth)
        newimg[newimg < 0] = 0
        newimg[newimg > 1] = 1
        if not normal:
            newimg = (newimg * 255).astype('uint8')
        return newimg

import cv2   
def clahe_equalized(imgs):
    assert (len(imgs.shape)==3)  #3D arrays
    #create a CLAHE object (Arguments are optional).
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    imgs_equalized = np.empty(imgs.shape)
    for i in range(len(imgs)):
        imgs_equalized[i,:,:] = clahe.apply(np.array(imgs[i,:,:], dtype = np.uint8))
    return imgs_equalized

# 根据肝脏真实分割图,将原始图片进行裁剪为 以肝脏为中心,指定宽、高的图片
def crop_images_func(refer_images, target_images, target_tumors):
    maxw=maxh=0
    assert refer_images.shape == target_images.shape == target_tumors.shape
    
    crop_images = []
    crop_tumors = []
    
    for i in range(refer_images.shape[0]):
        # Create figure and axes
#        fig,ax = plt.subplots(1)
        
        mask = refer_images[i]
        
        # find coordinates of liver
        coor = np.nonzero(mask) 
        xmin = coor[0][0] # x代表了行
        xmax = coor[0][-1]
        coor[1].sort() # 直接改变原数组,没有返回值
        ymin = coor[1][0]
        ymax = coor[1][-1]
        
        width_center = (ymax + ymin) // 2
        height_center = (xmax + xmin) // 2
        
        # pre-parameter: height:266, width:334
        # 参数的选定:是之前随机运行后,挑出的最大值,然后适当扩大后的结果
        height = 280
        width = 360
        istart = int(height_center - height/2)
        
        #注意逻辑!
        if istart < 0:
            istart = 0
            iend = height
        else:
            iend = int(istart + height)
        if iend > 512:
            istart = 512 - height
            iend = 512
            
        jstart = int(width_center - width/2)
        if jstart < 0:
            jstart = 0
            jend = width
            
        jend = int(jstart + width)
        
        if jend > 512:
            jstart = 512 - width
            jend = 512
    
#        print("[%d:%d,%d:%d]"%(istart,iend,jstart,jend))

        mask_crop = target_images[i,istart:iend,jstart:jend]   
        tumors_crop = target_tumors[i,istart:iend,jstart:jend]
        
#        ax.imshow(mask_crop,cmap=plt.cm.gray)
        
        crop_images.append(mask_crop)
        crop_tumors.append(tumors_crop)
        
    crop_images = np.asarray(crop_images)
    crop_tumors = np.asarray(crop_tumors)
    return (crop_images,crop_tumors)

"""
训练数据
第一步:读取数据
第二步:找到具有肿瘤的切片(具有肿瘤的切片一定是肝脏也在的)
第三步:预处理
       窗口化、自适应直方图均衡化、归一化、颜色翻转、ROI
第四步:裁剪
第五步:将数据写入文件
"""
# 工具类在博客里有写
from HDF5DatasetWriter import HDF5DatasetWriter

dataset = HDF5DatasetWriter(image_dims=(1967, 280, 360, 1),
                            mask_dims=(1967, 280, 360, 1),
                            outputPath="../data_train/LITS_train_tumor_crop.h5")


count = 0
for i in level12:
    seg = sitk.ReadImage(niiSegPath+ "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath+"volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)

    seg_liver = segimg.copy()
    seg_liver[seg_liver>0] = 1

    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    
    # 只选择ROI区域
    srcimg = srcimg * seg_liver
    
    start,end = getRangeImageDepth(seg_tumorimage)
    if start==0 and end == 0:
        print("continue")
        continue
    print("start:",start," end:",end)
    
    theshold = 1e-3 # 最小阈值
    
    filter_index = []
    
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > theshold:
            filter_index.append(j)
            
    if len(filter_index)<1:
        continue
    
    count += len(filter_index)
    
#    print("picked index:",filter_index)
   
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
#    
    srcimg = transform_ctdata(srcimg, 250,75,normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    
    # 注意,下面这两步顺序一定不能变,否则就不能达到正确的颜色翻转效果了
    srcimg = 1- srcimg
    # 只选择ROI区域
    srcimg = srcimg * seg_liver
    
    crop_images,crop_tumors = crop_images_func(seg_liver,srcimg,seg_tumorimage)
    
#    show_src_seg(crop_images,crop_tumors,index=i)
    
    crop_images = np.expand_dims(crop_images,axis=-1)
    crop_tumors = np.expand_dims(crop_tumors,axis=-1)

    
    dataset.add(crop_images,crop_tumors)

print(dataset.close())
    
    
dataset = HDF5DatasetWriter(image_dims=(133, 280, 360, 1),
                            mask_dims=(133, 280, 360, 1),
                            outputPath="../data_train/LITS_val_tumor_crop.h5")


count = 0
for i in test_list:
    seg = sitk.ReadImage(niiSegPath+ "segmentation-" + str(i) + ".nii", sitk.sitkUInt8)
    segimg = sitk.GetArrayFromImage(seg)
    src = sitk.ReadImage(niiImagePath+"volume-" + str(i) + ".nii")
    srcimg = sitk.GetArrayFromImage(src)

    seg_liver = segimg.copy()
    seg_liver[seg_liver>0] = 1

    seg_tumorimage = segimg.copy()
    seg_tumorimage[segimg == 1] = 0
    seg_tumorimage[segimg == 2] = 1
    
    
    
    
    start,end = getRangeImageDepth(seg_tumorimage)
    if start==0 and end == 0:
        print("continue")
        continue
    print("start:",start," end:",end)
    
    theshold = 1e-3 # 最小阈值
    
    filter_index = []
    
    for j in range(start, end+1):
        if np.mean(seg_tumorimage[j]) > theshold:
            filter_index.append(j)
            
    if len(filter_index)<1:
        continue
    
    count += len(filter_index)
    
    
#    print("picked index:",filter_index)
   
    srcimg = srcimg[filter_index]
    seg_liver = seg_liver[filter_index]
    seg_tumorimage = seg_tumorimage[filter_index]
    
    srcimg = transform_ctdata(srcimg, 250,75,normal=False)
    srcimg = clahe_equalized(srcimg)
    srcimg /= 255.
    
    srcimg = 1- srcimg
    # 只选择ROI区域
    srcimg = srcimg * seg_liver
    
    
    crop_images,crop_tumors = crop_images_func(seg_liver,srcimg,seg_tumorimage)
    
    show_src_seg(crop_images,crop_tumors,index=i)
    
    crop_images = np.expand_dims(crop_images,axis=-1)
    crop_tumors = np.expand_dims(crop_tumors,axis=-1)

    
    dataset.add(crop_images,crop_tumors)

print(dataset.close())

""" 
# 测试
from HDF5DatasetGenerator import HDF5DatasetGenerator

outputPath = '../data_train/LITS_train_tumor_crop.h5'
val_outputPath = '../data_train/LITS_val_tumor_crop.h5'
BATCH_SIZE = 8

reader = HDF5DatasetGenerator(dbPath=val_outputPath,batchSize=BATCH_SIZE)
train_iter = reader.generator()

src,seg = train_iter.__next__()

src = np.squeeze(src)
seg = np.squeeze(seg)

sample_stack(src)
sample_stack(seg)
"""

第二部分:

# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import random
import math
import tensorflow as tf
from HDF5DatasetGenerator import HDF5DatasetGenerator
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,Cropping2D,ZeroPadding2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from skimage import io
from keras import losses

# Set some parameters
IMG_WIDTH = 360
IMG_HEIGHT = 280
IMG_CHANNELS = 1
TOTAL = 1967 # 总共的训练数据
TOTAL_VAL = 133 # 总共的validation数据
outputPath = '../data_train/LITS_train_tumor_crop.h5' # 训练文件
val_outputPath = '../data_train/LITS_val_tumor_crop.h5'
#checkpoint_path = 'model.ckpt'
BATCH_SIZE = 4

K.set_image_data_format('channels_last')
    
def dice_coef(y_true, y_pred):
    print("in loss function, y_true shape:",y_true.shape)
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)

# 疑问,不知道除n的操作是否该写?还是说keras会自动取平均
def weighted_binary_cross_entropy_loss(y_true, y_pred):
    """
    # 跟标准的结果差不多 0.068760,该结果:0.0685122
    print("y_pred shape ",K.int_shape(y_pred))
    
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    ce = - K.mean(y_true*K.log(K.epsilon()+y_pred) + (1-y_true)*K.log(1-y_pred+K.epsilon()))
    return ce
    """
    
    """
    # 跟标准结果一样
    b_ce = K.binary_crossentropy(y_true, y_pred)
    return b_ce
    """
    # 不确定是否正确
    
    # Calculate the binary crossentropy
    b_ce = K.binary_crossentropy(y_true, y_pred)
    one_weight = K.mean(y_true)
    zero_weight = 1 - one_weight
#    weight = zero_weight / one_weight
    # Apply the weights
    weight_vector = y_true * zero_weight  + (1. - y_true) * one_weight
    weighted_b_ce = weight_vector * b_ce

    # Return the mean error
    return K.mean(weighted_b_ce)

# 不确定是否正确?
def weighted_dice_loss(y_true, y_pred):
    mean = K.mean(y_true)
    w_1 = 1/mean**2
    w_0 = 1/(1-mean)**2
    y_true_f_1 = K.flatten(y_true)
    y_pred_f_1 = K.flatten(y_pred)
    y_true_f_0 = K.flatten(1-y_true)
    y_pred_f_0 = K.flatten(1-y_pred)
    
    intersection_0 = K.sum(y_true_f_0 * y_pred_f_0)
    intersection_1 = K.sum(y_true_f_1 * y_pred_f_1)

    return -2 * (w_0 * intersection_0 +w_1 * intersection_1)\
          / ((w_0 * (K.sum(y_true_f_0) + K.sum(y_pred_f_0))) \
             + (w_1 * (K.sum(y_true_f_1) + K.sum(y_pred_f_1))))

def get_crop_shape(target, refer):
        # width, the 3rd dimension
#        print(target.shape)
#        print(refer._keras_shape)
        cw = (target._keras_shape[2] - refer._keras_shape[2])
        assert (cw >= 0)
        if cw % 2 != 0:
            cw1, cw2 = int(cw/2), int(cw/2) + 1
        else:
            cw1, cw2 = int(cw/2), int(cw/2)
        # height, the 2nd dimension
        ch = (target._keras_shape[1] - refer._keras_shape[1])
        assert (ch >= 0)
        if ch % 2 != 0:
            ch1, ch2 = int(ch/2), int(ch/2) + 1
        else:
            ch1, ch2 = int(ch/2), int(ch/2)

        return (ch1, ch2), (cw1, cw2)

def get_unet():
    inputs = Input((IMG_HEIGHT, IMG_WIDTH , 1))
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    up_conv5 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5)
    
    ch, cw = get_crop_shape(conv4, up_conv5)
#    print("ch,cw",ch,cw)
#    
    up_conv5 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv5)
    up6 = concatenate([up_conv5, conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    
    up_conv6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6)
    
    ch, cw = get_crop_shape(conv3, up_conv6)
    up_conv6 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv6)
#    
    up7 = concatenate([up_conv6, conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    
    up_conv7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    up_conv7 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv7)

    up8 = concatenate([up_conv7, conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    
    up_conv8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    up_conv8 = ZeroPadding2D(padding=(ch,cw), data_format="channels_last")(up_conv8)

    up9 = concatenate([up_conv8, conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])

    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])

    return model



class UnetModel:
    
    def predict(self):
        model = get_unet()
        model.load_weights('weights2.h5')
        test_reader = HDF5DatasetGenerator(dbPath=outputPath,batchSize=30)
        test_iter = test_reader.generator()
        fixed_test_images, fixed_test_masks = test_iter.__next__()
#        print(model.evaluate(fixed_test_images, fixed_test_masks,BATCH_SIZE*5))
        
        imgs_mask_test = model.predict(fixed_test_images, verbose=1)
        test_reader.close()
        print('-' * 30)
        print('Saving predicted masks to files...')
        print('-' * 30)
        pred_dir = 'step2_train1'
        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)
        i = 0
        
        
        for image in imgs_mask_test:
            image = (image[:, :, 0] * 255.).astype(np.uint8)
            gt = (fixed_test_masks[i,:,:,0] * 255.).astype(np.uint8)
            ini = (fixed_test_images[i,:,:,0] *255.).astype(np.uint8)
            io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
            io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
            io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
            i += 1
        
    
    def train_and_predict(self):
        
        reader = HDF5DatasetGenerator(dbPath=outputPath,batchSize=BATCH_SIZE)
        train_iter = reader.generator()
        
        test_reader = HDF5DatasetGenerator(dbPath=val_outputPath,batchSize=BATCH_SIZE)
        test_iter = test_reader.generator()
        
#   
        
        model = get_unet()
        model_checkpoint = ModelCheckpoint('weights2.h5', monitor='val_loss', save_best_only=True)
        model.fit_generator(train_iter,steps_per_epoch=int(TOTAL/BATCH_SIZE),verbose=1,epochs=500,shuffle=True,
                            validation_data=test_iter, validation_steps=int(TOTAL_VAL/BATCH_SIZE) ,callbacks=[model_checkpoint])
#        
        reader.close()
        test_reader.close()
        
        
#        print('-'*30)
#        print('Loading and preprocessing test data...')
#        print('-'*30)
#        
#        print('-'*30)
#        print('Loading saved weights...')
#        print('-'*30)
#        model.load_weights('weights.h5')
#    
#        print('-'*30)
#        print('Predicting masks on test data...')
#        print('-'*30)
#        
#        
#        
#        # 不懂这儿为什么会是np格式
#        imgs_mask_test = model.predict(fixed_test_images, verbose=1)
#        np.save('imgs_mask_test.npy', imgs_mask_test)
#    
#        print('-' * 30)
#        print('Saving predicted masks to files...')
#        print('-' * 30)
#        pred_dir = 'preds'
#        if not os.path.exists(pred_dir):
#            os.mkdir(pred_dir)
#        i = 0
#        
#        
#        for image in imgs_mask_test:
#            image = (image[:, :, 0] * 255.).astype(np.uint8)
#            gt = (fixed_test_masks[i,:,:,0] * 255.).astype(np.uint8)
#            ini = (fixed_test_images[i,:,:,0] *255.).astype(np.uint8)
#            io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
#            io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
#            io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
#            i += 1

#model = get_unet()
#model.summary()
unet = UnetModel()
#unet.train_and_predict()
unet.train_and_predict()
#print("test")
   
        
  • 43
    点赞
  • 275
    收藏
    觉得还不错? 一键收藏
  • 28
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 28
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值