A Keras implementation of multi-class segmentation with UNet

This post has four parts: building the dataset, training, inference, and test images with results.

Contents

Building the dataset

Training

Inference

Test images and results

Building the dataset

The dataset consists of 420 images of size 224*400 created with a paint tool. Each image contains three classes: circles, rectangles, and background, each filled with a different color. Some of the training images and their label images are shown below.

(Figure: sample training images and their corresponding label images)
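
Since the shapes are simple, similar image/label pairs can also be generated programmatically. The sketch below is only a hypothetical way to produce such pairs with Pillow; the sizes, fill colors, and the ../train1 and ../label1 folders match the description and the data script below, but the original dataset was drawn by hand in a paint tool.

import os
import random
from PIL import Image, ImageDraw

ROWS, COLS = 224, 400                      # image size used throughout this post
CIRCLE_COLOR = (34, 177, 76)               # label fill color of circles
RECT_COLOR = (237, 28, 36)                 # label fill color of rectangles

os.makedirs("../train1", exist_ok=True)
os.makedirs("../label1", exist_ok=True)

for idx in range(10):                                # 10 sample pairs; the real set has 420
    img = Image.new("L", (COLS, ROWS), 0)            # grayscale training image
    lbl = Image.new("RGB", (COLS, ROWS), (0, 0, 0))  # color-coded label image
    draw_img, draw_lbl = ImageDraw.Draw(img), ImageDraw.Draw(lbl)

    # one random circle, drawn into both the image and the label
    x, y, r = random.randint(40, 360), random.randint(40, 180), random.randint(15, 35)
    draw_img.ellipse([x - r, y - r, x + r, y + r], fill=200)
    draw_lbl.ellipse([x - r, y - r, x + r, y + r], fill=CIRCLE_COLOR)

    # one random rectangle, drawn into both the image and the label
    x0, y0 = random.randint(10, 300), random.randint(10, 150)
    x1, y1 = x0 + random.randint(30, 80), y0 + random.randint(30, 60)
    draw_img.rectangle([x0, y0, x1, y1], fill=120)
    draw_lbl.rectangle([x0, y0, x1, y1], fill=RECT_COLOR)

    img.save("../train1/%d.bmp" % idx)
    lbl.save("../label1/%d.bmp" % idx)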

Training

Based on the fill colors, each label image is converted into a one-hot array of shape rows*cols*class_nums.
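
As a toy illustration of this encoding (a minimal sketch, independent of the full script below): a 2x2 label whose gray values follow the color table is first remapped to class indices and then expanded into one channel per class.

import numpy as np

colorDict_GRAY = np.array([0, 34, 237])   # gray value of channel 0: background, circle, rectangle

label = np.array([[0, 34],
                  [237, 0]])              # toy 2x2 label image

index = np.zeros_like(label)              # map gray value -> class index 0/1/2
for i, gray in enumerate(colorDict_GRAY):
    index[label == gray] = i

one_hot = np.eye(3)[index]                # shape (2, 2, 3): one channel per class
print(one_hot[:, :, 1])                   # channel 1 is 1 exactly where the circle is

The full data preparation script follows.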

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img, save_img
import numpy as np
import os


colorDict = [[0, 0, 0], [34, 177, 76], [237, 28, 36]]  # fill colors of background, circle, rectangle
colorDict_RGB = np.array(colorDict)
colorDict_GRAY = colorDict_RGB[:, 0]
num_classes = 3


def data_preprocess(label, class_num):
    for i in range(colorDict_GRAY.shape[0]):
        label[label == colorDict_GRAY[i]] = i

    new_label = np.zeros(label.shape + (class_num,))

    for i in range(class_num):
        new_label[label == i, i] = 1
    label = new_label
    return label


def visual(array):
    for j in range(num_classes):
        vis = array[:, :, j]
        vis = vis*255
        vis = vis.reshape(224, 400, 1)
        vis_out = array_to_img(vis)
        vis_out.show()


class dataProcess(object):
    def __init__(self, out_rows, out_cols, data_path="../train1", label_path="../label1",
                    test_path="../test1", npy_path="../npydata", img_type="bmp"):

        # data processing class: initialize paths and image size
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path
        self.label_path = label_path
        self.img_type = img_type
        self.test_path = test_path
        self.npy_path = npy_path
        self.num_classes = num_classes

    # create training data
    def create_train_data(self):
        print('-' * 30)
        print('Creating training images...')
        print('-' * 30)
        img_list = os.listdir(self.data_path)

        imgdatas = np.ndarray((len(img_list), self.out_rows, self.out_cols, 1), dtype=np.uint8)
        imglabels = np.ndarray((len(img_list), self.out_rows, self.out_cols, self.num_classes), dtype=np.uint8)

        for i in range(len(img_list)):
            img = load_img(self.data_path + "/" + img_list[i], color_mode="grayscale")
            img = img_to_array(img)
            imgdatas[i] = img

            label = load_img(self.label_path + "/" + img_list[i])
            label = img_to_array(label)[:, :, 0]
            label = data_preprocess(label, num_classes)
            # visual(label)
            imglabels[i] = label
        np.save(self.npy_path + '/imgs_train.npy', imgdatas)
        np.save(self.npy_path + '/imgs_mask_train.npy', imglabels)
        print('Saving to .npy files done.')

    def load_train_data(self):
        print('-' * 30)
        print('load train images...')
        print('-' * 30)
        imgs_train = np.load(self.npy_path + "/imgs_train.npy")
        imgs_mask_train = np.load(self.npy_path + "/imgs_mask_train.npy")
        imgs_train = imgs_train.astype('float32')
        imgs_mask_train = imgs_mask_train.astype('float32')
        imgs_train /= 255.0
        # the label array is already one-hot 0/1, so it is not rescaled
        return imgs_train, imgs_mask_train

    def create_test_data(self):
        test_list = []
        print('-' * 30)
        print('Creating test images...')
        print('-' * 30)
        img_list = os.listdir(self.test_path)
        testdatas = np.ndarray((len(img_list), self.out_rows, self.out_cols, 1), dtype=np.uint8)

        for i in range(len(img_list)):
            img = load_img(self.test_path + "/" + img_list[i], color_mode="grayscale")
            img = img_to_array(img)
            testdatas[i] = img
            test_list.append(img_list[i])

        np.save(self.npy_path + '/imgs_test.npy', testdatas)
        print('Saving to .npy files done.')
        return test_list

    def load_test_data(self):
        print('-' * 30)
        print('load test images...')
        print('-' * 30)
        imgs_test = np.load(self.npy_path + "/imgs_test.npy")
        imgs_test = imgs_test.astype('float32')
        imgs_test /= 255.0
        return imgs_test


if __name__ == "__main__":
    mydata = dataProcess(224, 400)
    mydata.create_train_data()
    imgs_train, imgs_mask_train = mydata.load_train_data()
    print(imgs_train.shape, imgs_mask_train.shape)
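
Once the .npy files exist, a quick sanity check confirms that every pixel is assigned to exactly one class and shows how strongly the background dominates (a hypothetical snippet, not part of the original script):

import numpy as np

masks = np.load("../npydata/imgs_mask_train.npy")   # (N, 224, 400, 3), one-hot 0/1

# every pixel must belong to exactly one of the three classes
assert np.all(masks.sum(axis=-1) == 1)

# per-class pixel frequency; the background class is usually far larger
freq = masks.reshape(-1, masks.shape[-1]).mean(axis=0)
print("class pixel frequencies:", freq)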

In conv10, change the number of output channels to class_nums, the activation to softmax, and the loss to 'categorical_crossentropy'.
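
These three changes belong together: softmax over the channel axis turns the class_nums scores at each pixel into a probability distribution, and categorical_crossentropy compares that distribution with the pixel's one-hot label. A small numeric check of what happens at a single pixel (a minimal sketch, unrelated to the training code below):

import numpy as np

logits = np.array([0.2, 2.0, -1.0])            # raw scores for one pixel, 3 classes
probs = np.exp(logits) / np.exp(logits).sum()  # softmax: non-negative, sums to 1
one_hot = np.array([0.0, 1.0, 0.0])            # ground truth: this pixel is a circle

loss = -np.sum(one_hot * np.log(probs))        # categorical crossentropy at this pixel
print(probs, loss)                             # loss shrinks as probs[1] approaches 1

The full model definition and training script: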

import numpy as np
from keras.models import *
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint
from keras import backend as keras
from my_test.data import *
from keras.models import Model


class myUnet(object):
    def __init__(self, img_rows = 224, img_cols = 400):
        self.img_rows = img_rows
        self.img_cols = img_cols

    def load_data(self):

        mydata = dataProcess(self.img_rows, self.img_cols)
        imgs_train, imgs_mask_train = mydata.load_train_data()
        return imgs_train, imgs_mask_train

    def get_unet(self):
        inputs = Input((self.img_rows, self.img_cols, 1))
        # network architecture definition
        '''
        #unet with crop(because padding = valid) 
        
        conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(inputs)
        print "conv1 shape:",conv1.shape
        conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv1)
        print "conv1 shape:",conv1.shape
        crop1 = Cropping2D(cropping=((90,90),(90,90)))(conv1)
        print "crop1 shape:",crop1.shape
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
        print "pool1 shape:",pool1.shape
        
        conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool1)
        print "conv2 shape:",conv2.shape
        conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv2)
        print "conv2 shape:",conv2.shape
        crop2 = Cropping2D(cropping=((41,41),(41,41)))(conv2)
        print "crop2 shape:",crop2.shape
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
        print "pool2 shape:",pool2.shape
        
        conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool2)
        print "conv3 shape:",conv3.shape
        conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv3)
        print "conv3 shape:",conv3.shape
        crop3 = Cropping2D(cropping=((16,17),(16,17)))(conv3)
        print "crop3 shape:",crop3.shape
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        print "pool3 shape:",pool3.shape
        
        conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool3)
        conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)
        crop4 = Cropping2D(cropping=((4,4),(4,4)))(drop4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
        
        conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool4)
        conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv5)
        drop5 = Dropout(0.5)(conv5)
        
        up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
        merge6 = merge([crop4,up6], mode = 'concat', concat_axis = 3)
        conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge6)
        conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv6)
        
        up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
        merge7 = merge([crop3,up7], mode = 'concat', concat_axis = 3)
        conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge7)
        conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv7)
        
        up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
        merge8 = merge([crop2,up8], mode = 'concat', concat_axis = 3)
        conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge8)
        conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv8)
        
        up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
        merge9 = merge([crop1,up9], mode = 'concat', concat_axis = 3)
        conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge9)
        conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9)
        conv9 = Conv2D(2, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9)
        '''
        conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
        conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
        pool1 = MaxPooling2D((2, 2))(conv1)
        conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
        conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
        pool2 = MaxPooling2D((2, 2))(conv2)
        conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
        conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
        conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)
        pool4 = MaxPooling2D((2, 2))(drop4)
        conv5 = Conv2D(1024, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
        conv5 = Conv2D(1024, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
        drop5 = Dropout(0.5)(conv5)
        up6 = Conv2D(512, (2,2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
        merge6 = concatenate([drop4, up6], axis = 3)
        conv6 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
        conv6 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
        up7 = Conv2D(256, (2,2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
        merge7 = concatenate([conv3, up7], axis = 3)
        conv7 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
        conv7 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
        up8 = Conv2D(128, (2,2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
        merge8 = concatenate([conv2, up8], axis = 3)
        conv8 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
        conv8 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
        up9 = Conv2D(64, (2,2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
        merge9 = concatenate([conv1, up9], axis = 3)
        conv9 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
        conv9 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
        conv9 = Conv2D(2, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
        # conv10 = Conv2D(1, (1,1), activation = 'sigmoid')(conv9)
        conv10 = Conv2D(class_nums, (1,1), activation = 'softmax')(conv9)  # class_nums output channels, per-pixel softmax
        model = Model(inputs = inputs, outputs = conv10)
        model.summary()
        # model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])
        model.compile(optimizer = Adam(lr = 1e-4), loss = 'categorical_crossentropy', metrics = ['accuracy'])
        return model

    def train(self):
        print("loading data")
        imgs_train, imgs_mask_train = self.load_data()
        print("loading data done")
        model = self.get_unet()
        print("got unet")
        model_checkpoint = ModelCheckpoint('my_unet.hdf5', monitor='loss',verbose=1, save_best_only=True)
        print('Fitting model...')
        model.fit(imgs_train, imgs_mask_train, batch_size=4, epochs=10, verbose=1, validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])


if __name__ == '__main__':
    class_nums = 3
    myunet = myUnet()
    myunet.train()
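
After training, the saved checkpoint can be checked quickly to confirm that the modified head behaves as described, i.e. it outputs class_nums softmax channels per pixel (a hypothetical snippet, assuming my_unet.hdf5 exists after a training run):

import numpy as np
from keras.models import load_model

model = load_model('my_unet.hdf5')
dummy = np.zeros((1, 224, 400, 1), dtype='float32')   # one blank input image
pred = model.predict(dummy)
print(pred.shape)                                     # expected: (1, 224, 400, 3)
print(pred[0, 0, 0].sum())                            # softmax: scores at a pixel sum to 1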


Inference

Split the prediction into its three channels and save each channel as a separate image:

from my_test.data import *
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import array_to_img


def save_img(test_list):
    print("array to image")
    imgs = np.load('../11/imgs_mask_test.npy')
    for i in range(imgs.shape[0]):
        img = imgs[i]
        for j in range(class_num):
            out = img[:, :, j]
            out = out.reshape(224, 400, 1)
            out = array_to_img(out)
            out.save("../11/" + str(j) + '_' + test_list[i])


unet_model_path = 'my_unet.hdf5'
model = load_model(unet_model_path)
class_num = 3
mydata = dataProcess(224, 400)
test_list = mydata.create_test_data()   # build imgs_test.npy before loading it
imgs_test = mydata.load_test_data()
imgs_mask_test = model.predict(imgs_test, batch_size=1, verbose=1)
np.save('../11/imgs_mask_test.npy', imgs_mask_test)
save_img(test_list)

Test images and results

The network input has shape rows*cols*1 and the output has shape rows*cols*class_nums. During data preprocessing, channel 0 was set as the mask of the background, channel 1 as the mask of the circles, and channel 2 as the mask of the rectangles. Splitting the output into its three channels therefore gives: channel 0 is the segmentation result for the background, channel 1 for the circles, and channel 2 for the rectangles.
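
Rather than saving one image per class channel, the three-channel softmax output can also be collapsed into a single color-coded result with argmax, reusing the fill colors from the data script (a minimal sketch; the output file name pred_color_*.bmp is only illustrative):

import numpy as np
from keras.preprocessing.image import array_to_img

colorDict = np.array([[0, 0, 0], [34, 177, 76], [237, 28, 36]], dtype=np.uint8)

imgs = np.load('../11/imgs_mask_test.npy')        # (N, 224, 400, 3) softmax scores
for i in range(imgs.shape[0]):
    class_map = np.argmax(imgs[i], axis=-1)       # (224, 400) class index per pixel
    color_map = colorDict[class_map]              # (224, 400, 3) color-coded result
    array_to_img(color_map, scale=False).save('../11/pred_color_%d.bmp' % i)

Channel 0 then maps back to black, channel 1 to the circle color, and channel 2 to the rectangle color, so the prediction can be compared with the label image directly.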
