keras实现Unet进行字符定位与识别分类

#coding=utf-8

import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras import backend as K
from keras.callbacks import (Callback, EarlyStopping, ModelCheckpoint,
                             ReduceLROnPlateau, TensorBoard)
from keras.layers import (Activation, BatchNormalization, Conv2D, Input,
                          MaxPooling2D, UpSampling2D, concatenate)
from keras.losses import categorical_crossentropy
from keras.models import Model
from keras.optimizers import RMSprop
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
from keras.utils.vis_utils import plot_model
from model.augmentations import randomHueSaturationValue, randomShiftScaleRotate, randomHorizontalFlip
class LossHistory(Callback):
    """Keras callback that records the training loss after every batch.

    The per-batch values accumulate in ``self.losses`` and can be
    inspected or plotted after ``model.fit`` returns.
    """

    def on_train_begin(self, logs=None):
        # Reset the history at the start of each training run.
        # (``logs=None`` instead of the original mutable ``{}`` default.)
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        # ``logs`` may legitimately be None; guard before reading 'loss'.
        logs = logs or {}
        self.losses.append(logs.get('loss'))
#unet model
def get_unet_128_muticlass(input_shape=(None, 128, 128, 3),
                 num_classes=1):
    """Build a 128x128 U-Net for multi-class per-pixel classification.

    Encoder: four double Conv-BN-ReLU stages (64 -> 512 filters), each
    followed by 2x2 max pooling; a 1024-filter bottleneck; decoder: four
    upsampling stages with skip connections and triple conv stacks; final
    1x1 softmax head over ``num_classes`` channels.

    :param input_shape: batch shape passed to ``Input(batch_shape=...)``.
    :param num_classes: number of output channels of the softmax head.
    :return: compiled ``Model`` (RMSprop lr=0.001, categorical
        cross-entropy, accuracy metric).
    """

    def conv_stack(x, filters, depth):
        # ``depth`` repetitions of Conv2D(3x3, same) -> BatchNorm -> ReLU.
        for _ in range(depth):
            x = Conv2D(filters, (3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
        return x

    inputs = Input(batch_shape=input_shape)

    # --- contracting path: 128 -> 64 -> 32 -> 16 -> 8 ------------------
    down1 = conv_stack(inputs, 64, 2)
    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)
    down2 = conv_stack(down1_pool, 128, 2)
    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)
    down3 = conv_stack(down2_pool, 256, 2)
    down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3)
    down4 = conv_stack(down3_pool, 512, 2)
    down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4)

    # --- bottleneck at 8x8 ---------------------------------------------
    center = conv_stack(down4_pool, 1024, 2)

    # --- expanding path with skip connections: 8 -> ... -> 128 ---------
    up4 = concatenate([down4, UpSampling2D((2, 2))(center)], axis=3)
    up4 = conv_stack(up4, 512, 3)
    up3 = concatenate([down3, UpSampling2D((2, 2))(up4)], axis=3)
    up3 = conv_stack(up3, 256, 3)
    up2 = concatenate([down2, UpSampling2D((2, 2))(up3)], axis=3)
    up2 = conv_stack(up2, 128, 3)
    up1 = concatenate([down1, UpSampling2D((2, 2))(up2)], axis=3)
    up1 = conv_stack(up1, 64, 3)

    # Per-pixel class probabilities.
    classify = Conv2D(num_classes, (1, 1), activation='softmax')(up1)

    model = Model(inputs=inputs, outputs=classify)
    model.compile(optimizer=RMSprop(lr=0.001),
                  loss=categorical_crossentropy,
                  metrics=['acc'])
    return model
# ---------------------------------------------------------------------------
# Training / inference script: builds the 3-class U-Net, prepares the
# image/mask pairs, trains with checkpointing, then visualises predictions.
# ---------------------------------------------------------------------------
import glob
import os

print('model summary...')
model = get_unet_128_muticlass(num_classes=3)
model.summary()
plot_model(model, 'model.png', show_shapes=True)

# Network input resolution (width, height) used by cv2.resize.
SIZE = (128, 128)


def fix_mask(mask):
    """Quantise a grayscale mask in place to the three canonical levels.

    Pixels below 100 become 0 and pixels above 128 become 255; the
    128-valued middle class is left as-is.  (The original also contained
    the no-op ``mask[mask == 128] = 128``, removed here.)
    """
    mask[mask < 100] = 0
    mask[mask > 128] = 255


def fix_mask_onehot(mask):
    """Quantise a grayscale mask in place to class indices 0/1/2.

    Currently unused by the pipeline (kept for experimentation).
    """
    mask[mask < 100] = 0
    mask[mask == 128] = 1
    mask[mask > 128] = 2


def _mask_to_onehot(mask):
    """Vectorised one-hot encoding of a quantised grayscale mask.

    Pixel > 200 -> class 0, pixel < 100 -> class 2, anything in between
    -> class 1.  Replaces the original per-pixel Python double loop
    (128*128 iterations, duplicated in both processing functions) with
    three NumPy boolean-mask assignments.
    """
    onehot = np.zeros(mask.shape + (3,), dtype=int)
    onehot[mask > 200, 0] = 1
    onehot[mask < 100, 2] = 1
    onehot[(mask >= 100) & (mask <= 200), 1] = 1
    return onehot


def train_process(data):
    """Augment one (image, mask) pair; return (float image, one-hot mask)."""
    img, mask = data
    img = img[:, :, :3]
    mask = mask[:, :, :3]
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    fix_mask(mask)
    img = cv2.resize(img, SIZE)
    mask = cv2.resize(mask, SIZE)
    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-50, 50),
                                   sat_shift_limit=(0, 0),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.0625, 0.0625),
                                       scale_limit=(-0.1, 0.1),
                                       rotate_limit=(-20, 20))
    img, mask = randomHorizontalFlip(img, mask)
    # Re-quantise: resize/rotation interpolation introduces new gray levels.
    fix_mask(mask)
    img = img / 255.
    return (img, _mask_to_onehot(mask))


def validation_process(data):
    """Same as train_process but without any data augmentation."""
    img, mask = data
    img = img[:, :, :3]
    mask = mask[:, :, :3]
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    fix_mask(mask)
    img = cv2.resize(img, SIZE)
    mask = cv2.resize(mask, SIZE)
    fix_mask(mask)
    img = img / 255.
    return (img, _mask_to_onehot(mask))


data_dir = r'data\train3_muticlass'  # renamed from ``dir`` (shadowed builtin)

# Build the training set: every .tif in data_dir, each augmented 30 times.
x_train = []
y_train = []
for file in glob.glob(os.path.join(data_dir, '*.tif')):
    print(file)
    print(os.path.join(data_dir, 'mask', os.path.split(file)[1]))
    x = cv2.imread(file, cv2.IMREAD_COLOR)
    y = cv2.imread(os.path.join(data_dir, 'mask', os.path.split(file)[1]),
                   cv2.IMREAD_COLOR)
    x = 255 - x  # images are stored inverted; the masks already match
    for _ in range(30):
        xx, yy = train_process((x, y))
        x_train.append(xx)
        y_train.append(yy)

# NOTE(review): the validation set is built from the *same* images as the
# training set (no held-out split), so val_loss mostly tracks memorisation
# — confirm this is intended.
x_val = []
y_val = []
for file in glob.glob(os.path.join(data_dir, '*.tif')):
    print(file)
    print(os.path.join(data_dir, 'mask', os.path.split(file)[1]))
    x = cv2.imread(file, cv2.IMREAD_COLOR)
    y = cv2.imread(os.path.join(data_dir, 'mask', os.path.split(file)[1]),
                   cv2.IMREAD_COLOR)
    x = 255 - x
    xx, yy = validation_process((x, y))
    x_val.append(xx)
    y_val.append(yy)

# Save the weights whenever the validation loss improves.
checkpointer = ModelCheckpoint(filepath="model_muticlass.w", verbose=0,
                               save_best_only=True, save_weights_only=True)
history = LossHistory()
earlystop = EarlyStopping(patience=5)

# Resume from the previous best checkpoint (the file must already exist).
model.load_weights('model_muticlass.w')
model.fit(np.array(x_train), np.array(y_train),
          epochs=100, batch_size=3, verbose=1,
          validation_data=(np.array(x_val), np.array(y_val)),
          callbacks=[checkpointer, history, earlystop])

# Reload the best checkpoint and visualise predictions for every image.
model.load_weights('model_muticlass.w')
for file in glob.glob(os.path.join(data_dir, '*.tif')):
    x = cv2.imread(file, cv2.IMREAD_COLOR)
    y = cv2.imread(os.path.join(data_dir, 'mask', os.path.split(file)[1]),
                   cv2.IMREAD_COLOR)
    x = 255 - x
    xx, yy = validation_process((x, y))
    x_val = [xx]
    print(np.shape(x_val))
    predicted_mask_batch = model.predict(np.array(x_val))
    print(np.shape(predicted_mask_batch))
    predicted_mask = predicted_mask_batch[0].reshape((128, 128, 3))
    # Was ``plt.imshow(xx[0])``, which drew only the first pixel row of
    # the (128, 128, 3) image; show the whole image under the overlay.
    plt.imshow(xx)
    plt.imshow(predicted_mask[:, :, 2], alpha=0.6)
    plt.show()
K.clear_session()
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值