# coding=utf-8
"""Train a three-class U-Net segmentation model with Keras and preview it.

Running this module builds the network, loads every ``*.tif`` image in
``DATA_DIR`` together with the same-named mask in ``DATA_DIR/mask``, trains
with on-the-fly augmentation, and finally overlays one predicted class
channel on every input image.
"""
import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
from keras.callbacks import (
    Callback,
    EarlyStopping,
    ModelCheckpoint,
    ReduceLROnPlateau,
    TensorBoard,
)
from keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Input,
    MaxPooling2D,
    UpSampling2D,
    concatenate,
)
from keras.losses import categorical_crossentropy
from keras.models import Model
from keras.optimizers import RMSprop
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
from keras.utils.vis_utils import plot_model

from model.augmentations import (
    randomHorizontalFlip,
    randomHueSaturationValue,
    randomShiftScaleRotate,
)

# Target (width, height) handed to cv2.resize for both images and masks.
SIZE = (128, 128)
# Number of channels produced by the one-hot mask encoding below.
NUM_CLASSES = 3
DATA_DIR = os.path.join('data', 'train3_muticlass')
MASK_DIR = os.path.join(DATA_DIR, 'mask')
WEIGHTS_PATH = 'model_muticlass.w'
EPOCHS = 100
BATCH_SIZE = 3
# How many randomly augmented samples are generated from each training image.
AUGMENTATIONS_PER_IMAGE = 30


class LossHistory(Callback):
    """Keras callback that records the training loss after every batch."""

    def on_train_begin(self, logs=None):
        """Reset the list of recorded losses at the start of training."""
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        """Append the batch loss (``None`` when Keras supplies no logs)."""
        self.losses.append((logs or {}).get('loss'))


def _conv_bn_relu(tensor, filters, repeats):
    """Apply ``repeats`` rounds of 3x3 Conv2D -> BatchNormalization -> ReLU."""
    for _ in range(repeats):
        tensor = Conv2D(filters, (3, 3), padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Activation('relu')(tensor)
    return tensor


def get_unet_128_muticlass(input_shape=(None, 128, 128, 3), num_classes=1):
    """Build and compile a U-Net with four down/up levels.

    Args:
        input_shape: Batch shape of the input tensor, including the batch
            dimension (``None`` for a variable batch size).
        num_classes: Number of output channels of the final 1x1 softmax
            convolution. With the default of 1 the softmax output is
            constantly 1, so callers should pass the real class count.

    Returns:
        A ``Model`` compiled with RMSprop (lr=0.001), categorical
        cross-entropy loss and the ``acc`` metric.
    """
    inputs = Input(batch_shape=input_shape)

    # Layers are created in the same order as a hand-unrolled definition
    # (encoder, center, decoder), so auto-generated layer names and the
    # weight order stay stable for previously saved checkpoints.
    skips = []
    tensor = inputs
    for filters in (64, 128, 256, 512):
        tensor = _conv_bn_relu(tensor, filters, repeats=2)
        skips.append(tensor)
        tensor = MaxPooling2D((2, 2), strides=(2, 2))(tensor)

    tensor = _conv_bn_relu(tensor, 1024, repeats=2)

    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        tensor = UpSampling2D((2, 2))(tensor)
        tensor = concatenate([skip, tensor], axis=3)
        tensor = _conv_bn_relu(tensor, filters, repeats=3)

    classify = Conv2D(num_classes, (1, 1), activation='softmax')(tensor)

    unet = Model(inputs=inputs, outputs=classify)
    unet.compile(optimizer=RMSprop(lr=0.001),
                 loss=categorical_crossentropy,
                 metrics=['acc'])
    return unet


def fix_mask(mask):
    """Snap a grayscale mask in place to the three levels 0, 128 and 255.

    Values below 100 become 0, values above 128 become 255 and everything in
    between becomes 128. (Previously the middle band was left untouched
    because ``mask[mask == 128] = 128`` is a no-op.)

    Args:
        mask: Mutable numeric array; modified in place.
    """
    low = mask < 100
    high = mask > 128
    mask[low] = 0
    mask[high] = 255
    mask[~(low | high)] = 128


def fix_mask_onehot(mask):
    """Replace grayscale mask values in place with class indices 0, 1 and 2.

    Values below 100 map to 0, values above 128 map to 2 and everything in
    between maps to 1. The band tests are evaluated before any assignment so
    freshly written indices are never re-classified.

    Args:
        mask: Mutable numeric array; modified in place.
    """
    low = mask < 100
    high = mask > 128
    mask[low] = 0
    mask[high] = 2
    mask[~(low | high)] = 1


def _mask_to_onehot(mask):
    """Encode a single-channel mask as an ``(H, W, 3)`` integer one-hot array.

    Channel 0 marks values above 200, channel 2 marks values below 100 and
    channel 1 marks everything else.
    """
    mask = np.asarray(mask)
    # Accept both (H, W) and (H, W, 1) layouts.
    mask = mask.reshape(mask.shape[0], mask.shape[1])
    high = mask > 200
    low = mask < 100
    middle = ~(high | low)
    return np.stack([high, middle, low], axis=-1).astype(int)


def _prepare_pair(data):
    """Drop extra channels, gray and snap the mask, and resize both to SIZE."""
    img, mask = data
    img = img[:, :, :3]
    # cvtColor allocates a new array, so the caller's mask is never mutated.
    mask = cv2.cvtColor(mask[:, :, :3], cv2.COLOR_BGR2GRAY)
    fix_mask(mask)
    img = cv2.resize(img, SIZE)
    mask = cv2.resize(mask, SIZE)
    return img, mask


def train_process(data):
    """Turn a raw ``(image, mask)`` pair into one augmented training sample.

    Args:
        data: Tuple ``(img, mask)`` of arrays with at least three channels,
            as returned by ``cv2.imread(..., cv2.IMREAD_COLOR)``.

    Returns:
        Tuple ``(img, mask_onehot)``: the augmented image divided by 255 and
        the matching three-channel one-hot mask.
    """
    img, mask = _prepare_pair(data)
    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-50, 50),
                                   sat_shift_limit=(0, 0),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.0625, 0.0625),
                                       scale_limit=(-0.1, 0.1),
                                       rotate_limit=(-20, 20))
    img, mask = randomHorizontalFlip(img, mask)
    # Resizing and warping interpolate between levels; snap them back.
    fix_mask(mask)
    img = img / 255.
    return (img, _mask_to_onehot(mask))


def validation_process(data):
    """Turn a raw ``(image, mask)`` pair into a sample without augmentation.

    Args:
        data: Tuple ``(img, mask)`` of arrays with at least three channels.

    Returns:
        Tuple ``(img, mask_onehot)``: the resized image divided by 255 and
        the matching three-channel one-hot mask.
    """
    img, mask = _prepare_pair(data)
    # Resizing interpolates between levels; snap them back.
    fix_mask(mask)
    img = img / 255.
    return (img, _mask_to_onehot(mask))


def _mask_path_for(image_path):
    """Return the path of the mask that shares ``image_path``'s file name."""
    return os.path.join(MASK_DIR, os.path.basename(image_path))


def _load_pair(image_path):
    """Read an image and its mask, returning ``(inverted_image, mask)``.

    Raises:
        OSError: If OpenCV cannot read either file (``cv2.imread`` signals
            failure by returning ``None`` instead of raising).
    """
    mask_path = _mask_path_for(image_path)
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise OSError('could not read image: {}'.format(image_path))
    mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
    if mask is None:
        raise OSError('could not read mask: {}'.format(mask_path))
    return 255 - image, mask


def _weights_exist(path):
    """Report whether a checkpoint was already written at ``path``.

    Also checks ``path + '.index'`` because TensorFlow-format checkpoints are
    stored under that name rather than as a single file.
    """
    return os.path.exists(path) or os.path.exists(path + '.index')


print('model summary...')
model = get_unet_128_muticlass(num_classes=NUM_CLASSES)
model.summary()
plot_model(model, 'model.png', show_shapes=True)

# Sorted so the sample order is reproducible across runs and platforms.
image_paths = sorted(glob.glob(os.path.join(DATA_DIR, '*.tif')))
if not image_paths:
    raise OSError('no .tif images found in {}'.format(DATA_DIR))

x_train = []
y_train = []
for image_path in image_paths:
    print(image_path)
    print(_mask_path_for(image_path))
    x, y = _load_pair(image_path)
    for _ in range(AUGMENTATIONS_PER_IMAGE):
        xx, yy = train_process((x, y))
        x_train.append(xx)
        y_train.append(yy)

# NOTE(review): validation reuses the training images (without augmentation),
# so the validation loss is not an independent estimate - consider a held-out
# split.
x_val = []
y_val = []
for image_path in image_paths:
    print(image_path)
    print(_mask_path_for(image_path))
    x, y = _load_pair(image_path)
    xx, yy = validation_process((x, y))
    x_val.append(xx)
    y_val.append(yy)

# Keep only the weights with the lowest validation loss: the checkpoint is
# rewritten whenever the validation loss improves.
checkpointer = ModelCheckpoint(filepath=WEIGHTS_PATH, verbose=0,
                               save_best_only=True, save_weights_only=True)
history = LossHistory()
earlystop = EarlyStopping(patience=5)

# Resume from an earlier run when possible; on the very first run there is
# no checkpoint yet and loading it unconditionally would abort the script.
if _weights_exist(WEIGHTS_PATH):
    model.load_weights(WEIGHTS_PATH)

model.fit(np.array(x_train), np.array(y_train),
          epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=1,
          validation_data=(np.array(x_val), np.array(y_val)),
          callbacks=[checkpointer, history, earlystop])

# Switch back to the best checkpoint before predicting.
model.load_weights(WEIGHTS_PATH)

for image_path in image_paths:
    x, y = _load_pair(image_path)
    xx, _ = validation_process((x, y))
    batch = np.expand_dims(xx, axis=0)
    print(np.shape(batch))
    predicted_mask_batch = model.predict(batch)
    print(np.shape(predicted_mask_batch))
    predicted_mask = predicted_mask_batch[0]
    # Show the whole image (not just its first row); OpenCV loads BGR while
    # matplotlib expects RGB, hence the channel reversal.
    plt.imshow(xx[:, :, ::-1])
    # Overlay channel 2, the class encoded from mask values below 100.
    plt.imshow(predicted_mask[:, :, 2], alpha=0.6)
    plt.show()

K.clear_session()
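# keras实现Unet进行字符定位与识别分类 (Keras implementation of U-Net for character localization and recognition/classification)
# 最新推荐文章于 2024-09-05 16:18:34 发布 (web-page footer: latest recommended article published 2024-09-05 16:18:34)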