本文使用Keras实现U-Net2d网络，并在VOC2012数据集上进行语义分割。
UNet2d
首先是Unet2d网络的搭建
import keras.backend as K
from keras.engine import Input,Model
import keras
from keras.optimizers import Adam
from keras.layers import BatchNormalization,Activation,Conv2D,MaxPooling2D,Conv2DTranspose,UpSampling2D
import metrics as m
from keras.layers.core import Lambda
import numpy as np
def up_and_concate(down_layer, layer):
    """Upsample `down_layer` 2x and concatenate it with `layer` on the channel axis.

    Args:
        down_layer: channels-last tensor from the previous (coarser) decoder stage;
            its channel count is read from axis 3.
        layer: tensor whose height/width must equal twice those of `down_layer`
            (the same-scale encoder output, or an attention-gated version of it).

    Returns:
        Tensor with `in_channel // 2 + layer_channels` channels.
    """
    in_channel = down_layer.get_shape().as_list()[3]
    out_channel = in_channel // 2
    # Learned 2x upsampling that also halves the channel count.
    up = Conv2DTranspose(out_channel, [2, 2], strides=[2, 2])(down_layer)
    print("--------------")
    print(str(up.get_shape()))
    print(str(layer.get_shape()))
    print("--------------")
    # A raw backend call (K.concatenate) is not a Keras layer and breaks Model
    # construction. The built-in merge layer fixes that without a Lambda, whose
    # Python bytecode would otherwise be marshalled into every saved checkpoint.
    return keras.layers.concatenate([up, layer], axis=3)
def attention_block_2d(x, g, inter_channel):
    """Additive attention gate: re-weight skip features `x` using gate signal `g`.

    Args:
        x: skip tensor from the encoder, shape (?, x_height, x_width, x_channel).
        g: gate tensor from the coarser decoder stage, shape
            (?, x_height/2, x_width/2, g_channel).
        inter_channel: channel count of the intermediate projections.

    Returns:
        `x` multiplied by a per-pixel coefficient in (0, 1), same shape as `x`.
    """
    def _log(tensor):
        # Debug helper: print the static shape and pass the tensor through.
        print(str(tensor.get_shape()))
        return tensor

    print('attention_block:')
    # Strided 2x2 conv brings x down to the resolution of g.
    x_proj = _log(Conv2D(inter_channel, [2, 2], strides=[2, 2])(x))
    g_proj = _log(Conv2D(inter_channel, [1, 1], strides=[1, 1])(g))
    joint = _log(Activation('relu')(keras.layers.add([x_proj, g_proj])))
    # One logit per pixel -> sigmoid coefficient map at the gate resolution.
    logits = _log(Conv2D(1, [1, 1], strides=[1, 1])(joint))
    coeffs = _log(Activation('sigmoid')(logits))
    # Bring the coefficient map back to x's resolution and broadcast over channels.
    coeffs_full = _log(UpSampling2D(size=[2, 2])(coeffs))
    gated = _log(keras.layers.multiply([x, coeffs_full]))
    print('-----------------')
    return gated
def unet_model_2d_attention(input_shape,n_labels,batch_normalization=False,initial_learning_rate=0.00001,metrics=m.dice_coefficient):
    """Build and compile a residual 2-D U-Net whose skip connections pass through attention gates.

    Args:
        input_shape: shape without the batch dimension,
            (img_height, img_width, img_depth). Height and width should be
            divisible by 16 because of the four 2x2 max-poolings.
        n_labels: number of output classes (softmax over the last axis).
        batch_normalization: insert BatchNormalization inside every residual block.
        initial_learning_rate: learning rate for Adam.
        metrics: a single Keras metric; it is wrapped in a list for compile().

    Returns:
        The compiled keras Model, trained with m.dice_coefficient_loss.
    """
    encoder_filters = [64, 128, 256, 512]
    inputs = Input(input_shape)
    down_layer = []
    layer = inputs
    # Encoder: residual block, keep the output for the skip path, then downsample.
    for n_filters in encoder_filters:
        layer = res_block_v2(layer, n_filters, batch_normalization=batch_normalization)
        down_layer.append(layer)
        layer = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(layer)
        print(str(layer.get_shape()))
    # bottle_layer
    layer = res_block_v2(layer, 1024, batch_normalization=batch_normalization)
    print(str(layer.get_shape()))
    # Decoder: gate each skip tensor with the coarser decoder output, then upsample
    # that decoder output and concatenate it with the gated skip. Using the gated
    # skip alone would discard the decoder features: deeper information would
    # reach the next stage only through the single-channel attention mask.
    for skip, n_filters in zip(reversed(down_layer), reversed(encoder_filters)):
        gated_skip = attention_block_2d(skip, layer, n_filters // 2)
        layer = up_and_concate(layer, gated_skip)
        layer = res_block_v2(layer, n_filters, batch_normalization=batch_normalization)
        print(str(layer.get_shape()))
    # score_layer
    layer = Conv2D(n_labels, [1, 1], strides=[1, 1])(layer)
    print(str(layer.get_shape()))
    # softmax
    layer = Activation('softmax')(layer)
    print(str(layer.get_shape()))
    model = Model(inputs=inputs, outputs=layer)
    model.compile(optimizer=Adam(lr=initial_learning_rate), loss=m.dice_coefficient_loss, metrics=[metrics])
    return model
def unet_model_2d(input_shape,n_labels,batch_normalization=False,initial_learning_rate=0.00001,metrics=m.dice_coefficient):
    """Build and compile a residual 2-D U-Net with plain concatenating skip connections.

    Args:
        input_shape: shape without the batch dimension, (img_height, img_width, img_depth).
        n_labels: number of output classes (softmax over the last axis).
        batch_normalization: insert BatchNormalization inside every residual block.
        initial_learning_rate: learning rate for Adam.
        metrics: a single Keras metric; it is wrapped in a list for compile().

    Returns:
        The compiled keras Model, trained with m.dice_coefficient_loss.
    """
    stage_filters = (64, 128, 256, 512)
    inputs = Input(input_shape)
    skips = []
    x = inputs
    # Contracting path: residual block -> remember for the skip -> 2x2 max-pool.
    for n_filters in stage_filters:
        x = res_block_v2(x, n_filters, batch_normalization=batch_normalization)
        skips.append(x)
        x = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(x)
        print(str(x.get_shape()))
    # Bottleneck.
    x = res_block_v2(x, 1024, batch_normalization=batch_normalization)
    print(str(x.get_shape()))
    # Expanding path: walk the stages deepest-first.
    for skip, n_filters in zip(reversed(skips), reversed(stage_filters)):
        x = up_and_concate(x, skip)
        x = res_block_v2(x, n_filters, batch_normalization=batch_normalization)
        print(str(x.get_shape()))
    # Per-pixel class scores followed by softmax.
    x = Conv2D(n_labels, [1, 1], strides=[1, 1])(x)
    print(str(x.get_shape()))
    x = Activation('softmax')(x)
    print(str(x.get_shape()))
    model = Model(inputs=inputs, outputs=x)
    model.compile(optimizer=Adam(lr=initial_learning_rate), loss=m.dice_coefficient_loss, metrics=[metrics])
    return model
def res_block_v2(input_layer,out_n_filters,batch_normalization=False,kernel_size=(3,3),stride=(1,1),padding='same'):
    """Pre-activation residual block: two ([BN] -> ReLU -> Conv2D) units plus a shortcut.

    Args:
        input_layer: channels-last Keras tensor; its channel count is read from axis 3.
        out_n_filters: filters of both convolutions and of the block output.
        batch_normalization: insert BatchNormalization before each ReLU.
        kernel_size: kernel size of the two main-path convolutions.
        stride: block stride (int or pair). It is applied once, by the first
            convolution and by the projection shortcut.
        padding: Conv2D padding mode. NOTE(review): with padding='valid' and a
            kernel larger than 1, the main path shrinks while the 1x1 shortcut
            does not, so the final add would fail.

    Returns:
        Element-wise sum of the main path and the shortcut.
    """
    input_n_filters = input_layer.get_shape().as_list()[3]
    print(str(input_layer.get_shape()))
    strides = (stride, stride) if isinstance(stride, int) else tuple(stride)
    layer = input_layer
    for i in range(2):
        if batch_normalization:
            layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        # Downsample only once. Striding both convs would shrink the main path by
        # stride**2 while the shortcut shrinks by stride, breaking the add below.
        conv_strides = strides if i == 0 else (1, 1)
        layer = Conv2D(out_n_filters, kernel_size, strides=conv_strides, padding=padding)(layer)
    # Project the shortcut whenever the channel count OR the spatial size changes.
    # An identity shortcut on a strided block would not match the main path.
    if out_n_filters != input_n_filters or strides != (1, 1):
        skip_layer = Conv2D(out_n_filters, [1, 1], strides=strides, padding=padding)(input_layer)
    else:
        skip_layer = input_layer
    out_layer = keras.layers.add([layer, skip_layer])
    return out_layer
网络使用Keras中的Model类（函数式API）搭建：首先用Input(input_shape)定义输入。注意这里的input_shape不包含batch_size这一维，在这里就是(img_height, img_width, img_depth)，其中img_depth指图像的通道数（本文读取的RGB图像为3）。
metrics为训练时监控的评价指标，默认使用Dice系数（m.dice_coefficient）；损失函数为m.dice_coefficient_loss。
def up_and_concate(down_layer, layer):
    """Upsample `down_layer` 2x and concatenate it with `layer` on the channel axis.

    Args:
        down_layer: channels-last tensor from the previous (coarser) decoder stage;
            its channel count is read from axis 3.
        layer: tensor whose height/width must equal twice those of `down_layer`
            (the same-scale encoder output, or an attention-gated version of it).

    Returns:
        Tensor with `in_channel // 2 + layer_channels` channels.
    """
    in_channel = down_layer.get_shape().as_list()[3]
    out_channel = in_channel // 2
    # Learned 2x upsampling that also halves the channel count.
    up = Conv2DTranspose(out_channel, [2, 2], strides=[2, 2])(down_layer)
    print("--------------")
    print(str(up.get_shape()))
    print(str(layer.get_shape()))
    print("--------------")
    # A raw backend call (K.concatenate) is not a Keras layer and breaks Model
    # construction. The built-in merge layer fixes that without a Lambda, whose
    # Python bytecode would otherwise be marshalled into every saved checkpoint.
    return keras.layers.concatenate([up, layer], axis=3)
以上为skip connection（跳跃连接）的函数：down_layer是上一层（更粗尺度）解码层的输出，会先经过转置卷积上采样2倍；layer为同一尺度下采样（编码）层的输出。两者在通道维上拼接。
注意这里不能直接调用后端函数K.concatenate，否则构建Model时会报错，提示所用的张量不是Keras层的输出。需要用Lambda层把它包装起来，或者直接使用Keras内置的合并层keras.layers.concatenate（Concatenate）。
def res_block_v2(input_layer,out_n_filters,batch_normalization=False,kernel_size=(3,3),stride=(1,1),padding='same'):
    """Pre-activation residual block: two ([BN] -> ReLU -> Conv2D) units plus a shortcut.

    Args:
        input_layer: channels-last Keras tensor; its channel count is read from axis 3.
        out_n_filters: filters of both convolutions and of the block output.
        batch_normalization: insert BatchNormalization before each ReLU.
        kernel_size: kernel size of the two main-path convolutions.
        stride: block stride (int or pair). It is applied once, by the first
            convolution and by the projection shortcut.
        padding: Conv2D padding mode. NOTE(review): with padding='valid' and a
            kernel larger than 1, the main path shrinks while the 1x1 shortcut
            does not, so the final add would fail.

    Returns:
        Element-wise sum of the main path and the shortcut.
    """
    input_n_filters = input_layer.get_shape().as_list()[3]
    print(str(input_layer.get_shape()))
    strides = (stride, stride) if isinstance(stride, int) else tuple(stride)
    layer = input_layer
    for i in range(2):
        if batch_normalization:
            layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)
        # Downsample only once. Striding both convs would shrink the main path by
        # stride**2 while the shortcut shrinks by stride, breaking the add below.
        conv_strides = strides if i == 0 else (1, 1)
        layer = Conv2D(out_n_filters, kernel_size, strides=conv_strides, padding=padding)(layer)
    # Project the shortcut whenever the channel count OR the spatial size changes.
    # An identity shortcut on a strided block would not match the main path.
    if out_n_filters != input_n_filters or strides != (1, 1):
        skip_layer = Conv2D(out_n_filters, [1, 1], strides=strides, padding=padding)(input_layer)
    else:
        skip_layer = input_layer
    out_layer = keras.layers.add([layer, skip_layer])
    return out_layer
VOC读取
接下来是voc数据集的读取,voc数据集的目录如下:
Annotations存放每张图片的描述文件(.xml)，包含类别、检测框等信息，分割任务用不到。
ImageSets存放各个任务的train、val划分所用的图片名称列表。
JPEGImages存放所有图片(.jpg)。
SegmentationClass是用于语义分割的标签(.png)
SegmentationObject是用于实例分割的标签(.png)
本文只用到了ImageSets/Segmentation和SegmentationClass中的文件。
文件的读取如下:
import os

import numpy as np
import PIL
import scipy.misc as misc
import tensorflow as tf
from PIL import Image
def make_one_hot(x,n):
    """Convert an integer label map into a one-hot encoded float64 array.

    Args:
        x: integer numpy array of class indices, e.g. shape (height, width).
        n: number of classes.

    Returns:
        float64 array of shape x.shape + (n,) holding a single 1 along the last
        axis for each element of x. Values must lie in [-n, n): larger values
        raise IndexError, and negative values wrap as in numpy indexing.
    """
    # Row k of the identity matrix is the one-hot vector for class k. Indexing it
    # with the whole label map builds the result in one vectorized step, instead
    # of a Python loop over every pixel.
    return np.eye(n)[x]
class voc_reader:
    """Sequential batch reader for the PASCAL VOC2012 semantic-segmentation split.

    Image ids come from ImageSets/Segmentation/{train,val}.txt under voc_root.
    Images are read from JPEGImages (.jpg) and labels from SegmentationClass (.png).
    Each batch is returned as (images, one_hot_labels) with shapes
    (batch, resize_height, resize_width, 3) and
    (batch, resize_height, resize_width, n_classes).
    Batches are served in file order. An incomplete final batch is never served;
    the batch index wraps to 0 instead.
    """

    def __init__(self, resize_width, resize_height, train_batch_size, val_batch_size,
                 voc_root="D:\\pyproject\\data\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012",
                 n_classes=21):
        """
        Args:
            resize_width, resize_height: spatial size every sample is resized to.
            train_batch_size, val_batch_size: samples per batch for each split.
            voc_root: the VOC2012 directory (defaults to the previously hard-coded path).
            n_classes: depth of the one-hot labels (defaults to the previously hard-coded 21).
        """
        split_dir = os.path.join(voc_root, "ImageSets", "Segmentation")
        self.train_file_name_list = self.load_file_name_list(file_path=os.path.join(split_dir, "train.txt"))
        self.val_file_name_list = self.load_file_name_list(file_path=os.path.join(split_dir, "val.txt"))
        # Directory prefixes keep a trailing separator, as the original string paths did.
        self.row_file_path = os.path.join(voc_root, "JPEGImages", "")
        self.label_file_path = os.path.join(voc_root, "SegmentationClass", "")
        self.n_classes = n_classes
        self.train_batch_index = 0
        self.val_batch_index = 0
        self.resize_width = resize_width
        self.resize_height = resize_height
        self.n_train_file = len(self.train_file_name_list)
        self.n_val_file = len(self.val_file_name_list)
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        print(self.n_train_file)
        print(self.n_val_file)
        self.n_train_steps_per_epoch = self.n_train_file // self.train_batch_size
        self.n_val_steps_per_epoch = self.n_val_file // self.val_batch_size

    def load_file_name_list(self, file_path):
        """Return the non-empty, whitespace-stripped lines of file_path.

        Blank lines are skipped. The previous readline loop stopped reading at
        the first blank line and silently dropped everything after it.
        """
        with open(file_path, 'r') as file_to_read:
            stripped = (line.strip() for line in file_to_read)
            return [name for name in stripped if name]

    @staticmethod
    def _check_batch_size(requested, configured):
        """Reject a per-call batch size that differs from the one fixed at construction."""
        if requested is not None and requested != configured:
            raise ValueError('batch_size %d does not match the reader batch size %d'
                             % (requested, configured))

    def _load_sample(self, file_name):
        """Load one image and its one-hot label, both resized to (resize_height, resize_width)."""
        # PIL sizes are (width, height). Passing (height, width) produced arrays of
        # shape (width, height, ...) that do not fit the batch buffers unless square.
        size = (self.resize_width, self.resize_height)
        with Image.open(os.path.join(self.row_file_path, file_name + '.jpg')) as raw_img:
            # convert('RGB') guarantees 3 channels even for grayscale/CMYK JPEGs.
            img = np.array(raw_img.convert('RGB').resize(size, Image.NEAREST))
        with Image.open(os.path.join(self.label_file_path, file_name + '.png')) as raw_label:
            # NEAREST keeps every label value a valid class index.
            label = np.array(raw_label.resize(size, Image.NEAREST), dtype=np.int32)
        # Pixels labelled 255 (presumably VOC's void/boundary marker) are folded into class 0.
        label[label == 255] = 0
        return img, make_one_hot(label, self.n_classes)

    def _read_batch(self, file_name_list, batch_size, batch_index, n_steps):
        """Load batch number batch_index, wrapping to 0 once n_steps is reached.

        Returns (imgs, labels, index of the following batch).
        """
        imgs = np.zeros((batch_size, self.resize_height, self.resize_width, 3))
        labels = np.zeros((batch_size, self.resize_height, self.resize_width, self.n_classes))
        if batch_index >= n_steps:
            print("next epoch")
            batch_index = 0
        print('------------------')
        print(batch_index)
        for i in range(batch_size):
            index = batch_size * batch_index + i
            print('index' + str(index))
            imgs[i], labels[i] = self._load_sample(file_name_list[index])
        print('------------------')
        return imgs, labels, batch_index + 1

    def next_train_batch(self, batch_size=None):
        """Return the next training batch as (imgs, one_hot_labels).

        batch_size is optional. It exists for callers that pass it and must
        equal the train_batch_size given to the constructor.
        """
        self._check_batch_size(batch_size, self.train_batch_size)
        train_imgs, train_labels, self.train_batch_index = self._read_batch(
            self.train_file_name_list, self.train_batch_size,
            self.train_batch_index, self.n_train_steps_per_epoch)
        return train_imgs, train_labels

    def next_val_batch(self, batch_size=None):
        """Return the next validation batch as (imgs, one_hot_labels).

        batch_size is optional. It exists for callers that pass it and must
        equal the val_batch_size given to the constructor.
        """
        self._check_batch_size(batch_size, self.val_batch_size)
        val_imgs, val_labels, self.val_batch_index = self._read_batch(
            self.val_file_name_list, self.val_batch_size,
            self.val_batch_index, self.n_val_steps_per_epoch)
        return val_imgs, val_labels
Train
先构造训练集和验证集的generator，供后面的fit_generator使用。
def train_generator_data(batch_size,voc_reader):
    """Yield (images, one-hot labels) training batches forever, for fit_generator.

    Args:
        batch_size: kept for signature compatibility. The batch size actually
            used is the train_batch_size the reader was constructed with.
        voc_reader: object providing next_train_batch(), e.g. a voc_reader instance.
    """
    while True:
        # The reader's batch size is fixed when it is constructed. Calling
        # next_train_batch() with no arguments matches voc_reader's signature;
        # passing batch_size raised TypeError on the first batch.
        x, y = voc_reader.next_train_batch()
        yield (x, y)
def val_generator_data(batch_size,voc_reader):
    """Yield (images, one-hot labels) validation batches forever, for fit_generator.

    Args:
        batch_size: kept for signature compatibility. The batch size actually
            used is the val_batch_size the reader was constructed with.
        voc_reader: object providing next_val_batch(), e.g. a voc_reader instance.
    """
    while True:
        # The reader's batch size is fixed when it is constructed. Calling
        # next_val_batch() with no arguments matches voc_reader's signature;
        # passing batch_size raised TypeError on the first batch.
        x, y = voc_reader.next_val_batch()
        yield (x, y)
随后定义callback
def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1, early_stopping_patience=None):
    """Assemble the Keras callbacks used during training.

    Always includes a best-only ModelCheckpoint to model_file, an appending
    CSVLogger to logging_file, and a default TensorBoard.
    If learning_rate_epochs is truthy, the learning rate follows a
    LearningRateScheduler built from step_decay (defined elsewhere, not shown
    here). Otherwise ReduceLROnPlateau scales it by learning_rate_drop after
    learning_rate_patience epochs without improvement.
    An EarlyStopping callback is appended only when early_stopping_patience is truthy.

    Returns:
        list of callback instances.
    """
    callbacks = [
        ModelCheckpoint(model_file, save_best_only=True),
        CSVLogger(logging_file, append=True),
        TensorBoard(),
    ]
    if not learning_rate_epochs:
        lr_callback = ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                        verbose=verbosity)
    else:
        schedule = partial(step_decay, initial_lrate=initial_learning_rate,
                           drop=learning_rate_drop, epochs_drop=learning_rate_epochs)
        lr_callback = LearningRateScheduler(schedule)
    callbacks.append(lr_callback)
    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    return callbacks
加载已经训练过的模型
def load_old_model(model_file):
    """Load a saved Keras model, registering the custom objects it may reference.

    The Dice loss/metric are always registered. InstanceNormalization is
    registered only when keras_contrib can be imported.
    If loading fails with a ValueError that mentions InstanceNormalization, a
    ValueError with installation instructions is raised instead. Any other
    ValueError propagates unchanged.
    """
    print("Loading pre-trained model")
    custom_objects = {'dice_coefficient_loss': dice_coefficient_loss, 'dice_coefficient': dice_coefficient}
    try:
        from keras_contrib.layers import InstanceNormalization
    except ImportError:
        pass  # optional dependency; models without that layer still load
    else:
        custom_objects["InstanceNormalization"] = InstanceNormalization
    try:
        return load_model(model_file, custom_objects=custom_objects)
    except ValueError as error:
        if 'InstanceNormalization' not in str(error):
            raise error
        raise ValueError(str(error) + "\n\nPlease install keras-contrib to use InstanceNormalization:\n"
                         "'pip install git+https://www.github.com/keras-team/keras-contrib.git'")
train_model函数
def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None):
    """
    Train a Keras model with generator input via model.fit_generator.

    :param model: Keras model that will be trained.
    :param model_file: Path the ModelCheckpoint callback saves the best model to.
    :param training_generator: Generator that yields training batches.
    :param validation_generator: Generator that yields validation batches.
    :param steps_per_epoch: Number of training batches drawn per epoch.
    :param validation_steps: Number of validation batches drawn per epoch.
    :param initial_learning_rate: Learning rate passed to the step-decay schedule.
    :param learning_rate_drop: Factor applied when the learning rate is reduced.
    :param learning_rate_epochs: If set, reduce the learning rate every this many epochs (step decay).
    :param n_epochs: Total number of epochs.
    :param learning_rate_patience: If learning_rate_epochs is not set, reduce the learning rate
        after this many epochs without improvement (default 20).
    :param early_stopping_patience: If set, stop training after this many epochs without improvement.
    :return: None
    """
    callbacks = get_callbacks(model_file,
                              initial_learning_rate=initial_learning_rate,
                              learning_rate_drop=learning_rate_drop,
                              learning_rate_epochs=learning_rate_epochs,
                              learning_rate_patience=learning_rate_patience,
                              early_stopping_patience=early_stopping_patience)
    model.fit_generator(generator=training_generator, steps_per_epoch=steps_per_epoch, epochs=n_epochs,
                        validation_data=validation_generator, validation_steps=validation_steps,
                        callbacks=callbacks)