基于 Keras + TensorFlow 实现的 SegNet 语义分割模型(数据集为 PASCAL VOC 2012)。
完整代码如下:
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D,MaxPooling2D,UpSampling2D,BatchNormalization,Reshape,Permute,Activation
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import matplotlib.pyplot as plt
# Seed NumPy's global RNG so repeated runs draw the same random numbers.
seed = 7
np.random.seed(seed)

#data_shape = 360*480
# Size every image and label map is resized to before entering the network.
img_w = 256
img_h = 256

# Raw label values the encoder accepts: 0..20 plus 255
# (255 is presumably VOC's void/boundary marker -- confirm against the dataset).
classes = list(map(float, range(21))) + [255.]

# 21 + 1 labels in total; per the original note, one of them is background.
n_label = len(classes)

# Maps the raw label values onto contiguous indices 0..n_label-1 so they can
# be one-hot encoded.
labelencoder = LabelEncoder()
labelencoder.fit(classes)
#原keras中的load_img只读取两种格式,gray和RGB,其他类型的图像都会转为RBG进行读取,本处对此进行更改,保留原图格式
# Keras' own load_img only returns grayscale or RGB images and converts every
# other mode to RGB; this variant keeps the file's native mode instead.
def load_img(path, grayscale=False, target_size=None):
    """Open an image with PIL, optionally converting to grayscale and resizing.

    Args:
        path: Path of the image file to open.
        grayscale: When true, convert the image to mode 'L' unless it already is.
        target_size: Optional (height, width) pair. When given and different
            from the image's current size, the image is resized to it.

    Returns:
        The PIL image, in its original mode unless `grayscale` forced 'L'.
    """
    image = Image.open(path)
    if grayscale and image.mode != 'L':
        image = image.convert('L')
    if not target_size:
        return image
    # PIL reports and accepts sizes as (width, height), the reverse of target_size.
    wanted_size = (target_size[1], target_size[0])
    if image.size == wanted_size:
        return image
    return image.resize(wanted_size)
# Read the train/val image-id lists once at import time to learn the dataset
# sizes. `with` closes each list file as soon as its lines are in memory.
with open(r'/media/wmy/document/BigData/VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt', 'r') as _list_file:
    train_url = _list_file.readlines()
#trainval_url = open(r'/media/wmy/document/BigData/VOCdevkit/VOC2012/ImageSets/Segmentation/trainval.txt','r').readlines()
with open(r'/media/wmy/document/BigData/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt', 'r') as _list_file:
    val_url = _list_file.readlines()

# Number of lines in each list file; used to size the training run.
train_numb = len(train_url)
valid_numb = len(val_url)
# Single-argument print(...) produces the same output as the original
# Python 2 print statements and also parses under Python 3.
print("the number of train data is %s" % train_numb)
print("the number of val data is %s" % valid_numb)

# Dataset root; image and label paths are built by appending to this prefix.
filepath ='/media/wmy/document/BigData/VOCdevkit/VOC2012/'
def _generate_batches(list_name, batch_size):
    """Yield (images, one-hot labels) batches forever for one image-id list.

    Args:
        list_name: File name inside `filepath + 'ImageSets/Segmentation/'`
            holding one image id per line (e.g. 'train.txt').
        batch_size: Number of images per yielded batch; must be >= 1.

    Yields:
        Tuples `(data, label)` where `data` stacks `batch_size` image arrays
        and `label` has shape (batch_size, img_w * img_h, n_label).

    Raises:
        ValueError: If `batch_size` is below 1 or the list holds no ids.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

    list_path = filepath + 'ImageSets/Segmentation/' + list_name
    # Close the list file right after reading; skip blank lines and strip
    # any '\r\n' endings so they cannot end up inside the image paths.
    with open(list_path, 'r') as list_file:
        names = [line.strip() for line in list_file if line.strip()]
    if not names:
        raise ValueError("no image ids found in %s" % list_path)

    while True:
        images, labels = [], []
        for name in names:
            # NOTE: (img_w, img_h) is passed where load_img expects
            # (height, width); harmless while both are equal.
            image = load_img(filepath + 'JPEGImages/' + name + '.jpg',
                             target_size=(img_w, img_h))
            images.append(img_to_array(image))

            mask = load_img(filepath + 'SegmentationClass/' + name + '.png',
                            target_size=(img_w, img_h))
            # One raw label value per pixel, flattened to a vector.
            labels.append(img_to_array(mask).reshape((img_w * img_h,)))

            if len(images) == batch_size:
                data = np.array(images)
                # Raw label values -> encoder indices -> one-hot rows.
                encoded = labelencoder.transform(np.array(labels).flatten())
                one_hot = to_categorical(encoded, num_classes=n_label)
                one_hot = one_hot.reshape((batch_size, img_w * img_h, n_label))
                yield (data, one_hot)
                images, labels = [], []
        # Images left over at the end of a pass (fewer than batch_size) are
        # dropped when the next pass resets the buffers.


def generateData(batch_size):
    """Return an endless generator of training batches (ids from train.txt)."""
    return _generate_batches('train.txt', batch_size)


def generateValidData(batch_size):
    """Return an endless generator of validation batches (ids from val.txt)."""
    return _generate_batches('val.txt', batch_size)
def _add_conv_block(model, filters, count, **first_layer_kwargs):
    """Append `count` 3x3 same-padded ReLU convolutions, each followed by batch norm.

    `first_layer_kwargs` are forwarded to the first Conv2D only (used to pass
    `input_shape` to the network's very first layer).
    """
    for index in range(count):
        extra = first_layer_kwargs if index == 0 else {}
        model.add(Conv2D(filters, (3, 3), strides=(1, 1), padding='same',
                         activation='relu', **extra))
        model.add(BatchNormalization())


def SegNet():
    """Build and compile the SegNet-style encoder/decoder network.

    The encoder is five conv stacks, each followed by 2x2 max pooling
    (256 -> 128 -> 64 -> 32 -> 16 -> 8); the decoder mirrors it with 2x2
    upsampling before each stack. A 1x1 convolution produces `n_label` scores
    per pixel, which are flattened to (img_w * img_h, n_label) and passed
    through softmax.

    Returns:
        The compiled `Sequential` model (categorical cross-entropy, SGD,
        accuracy metric). Its summary is printed as a side effect.
    """
    # (filters, number of conv+BN pairs) per stage, in layer order.
    encoder_stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
    decoder_stages = ((512, 3), (512, 3), (256, 3), (128, 2), (64, 2))

    model = Sequential()

    # encoder
    for stage, (filters, count) in enumerate(encoder_stages):
        if stage == 0:
            # Only the first layer of a Sequential model needs input_shape.
            _add_conv_block(model, filters, count, input_shape=(img_w, img_h, 3))
        else:
            _add_conv_block(model, filters, count)
        model.add(MaxPooling2D(pool_size=(2, 2)))

    # decoder
    for filters, count in decoder_stages:
        model.add(UpSampling2D(size=(2, 2)))
        _add_conv_block(model, filters, count)

    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
    # The conv output is channels-last, (rows, cols, n_label), so flattening
    # the two spatial axes directly gives one n_label-vector per pixel.
    # The earlier Reshape((n_label, img_w*img_h)) + Permute((2, 1)) is only
    # valid for channels-first tensors; here it mixed values from unrelated
    # pixels into each softmax row.
    # NOTE(review): weights trained with the earlier head should still load
    # (Reshape/Permute hold no weights) but were fitted to the scrambled
    # layout, so retraining is expected to be necessary.
    model.add(Reshape((img_w * img_h, n_label)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    model.summary()
    return model
def train(batch_size=5, epochs=50):
    """Train SegNet on the VOC train list, validating on the val list.

    The best model by validation accuracy is checkpointed to
    'Segnet_params_1.h5'.

    Args:
        batch_size: Images per batch requested from both generators.
        epochs: Number of training epochs.
    """
    model = SegNet()
    modelcheck = ModelCheckpoint('Segnet_params_1.h5', monitor='val_acc',
                                 save_best_only=True, mode='max')
    callbacks = [modelcheck]

    # A step consumes one batch, not one image. The generators only emit full
    # batches, so one pass over a list is count // batch_size steps. Using the
    # raw sample counts made every epoch run batch_size passes.
    train_steps = max(1, train_numb // batch_size)
    valid_steps = max(1, valid_numb // batch_size)

    # NOTE(review): `max_q_size` is the old spelling of this argument; later
    # Keras 2 releases call it `max_queue_size`. Kept as in the original.
    model.fit_generator(generator=generateData(batch_size),
                        steps_per_epoch=train_steps,
                        epochs=epochs,
                        verbose=2,
                        validation_data=generateValidData(batch_size),
                        validation_steps=valid_steps,
                        callbacks=callbacks,
                        max_q_size=1)
def predict(weights_path='Segnet_params.h5'):
    """Interactively segment images with a trained SegNet.

    Builds the network, loads weights from `weights_path`, then repeatedly
    reads an image path from stdin, predicts one class per pixel, prints the
    distinct predicted labels and writes the label map to '1.png'
    (overwritten on every iteration). An empty line or end-of-input ends the
    loop.

    Args:
        weights_path: Weights file to load. The default is the name used by
            the original script; train() checkpoints to 'Segnet_params_1.h5',
            so pass that name to use the result of a fresh training run.
    """
    model = SegNet()
    model.load_weights(weights_path)
    while True:
        print("please input the test img path:")
        try:
            test_imgpath = raw_input().strip()
        except EOFError:
            break
        if not test_imgpath:
            break
        # load_img keeps the file's native mode; force three channels so
        # grayscale/palette/RGBA inputs still match the network's input.
        img = load_img(test_imgpath, target_size=(img_w, img_h)).convert('RGB')
        img = img_to_array(img).reshape((1, img_h, img_w, -1))
        pred = model.predict_classes(img, verbose=2)
        # Turn encoder indices back into the raw label values from `classes`.
        pred = labelencoder.inverse_transform(pred[0])
        print(np.unique(pred))
        pred = pred.reshape((img_h, img_w)).astype(np.uint8)
        pred_img = Image.fromarray(pred)
        pred_img.save('1.png', format='png')
        # Disabled visual check, kept for reference:
        # print pred
        # plt.subplot(2,1,1)
        # plt.imshow(img.reshape((img_h,img_w,3)).astype(np.uint8))
        # plt.subplot(2,1,2)
        # plt.imshow(pred)
        # plt.show()
def main():
    """Train the network, then start the interactive prediction loop."""
    train()
    predict()


if __name__ == '__main__':
    main()