先看一下效果,虽然算不上很好。
直接上代码:
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import time
import numpy as np
import io
import PIL
from IPython.display import clear_output
import cv2
import sys
sys.path.append("/opt/LIP/examples")
from tensorflow_examples.models.pix2pix import pix2pix
# Spatial size of the (square) images fed to the network.
IMG_WIDTH = 128
# The original assigned IMG_WIDTH twice; the second line was meant to be the height.
IMG_HEIGHT = 128
# Root directories of the images and of their segmentation masks; each is
# expected to contain 'train/' and 'val/' sub-directories.
IM_PATH = '/opt/LIP/images/'
MS_PATH = '/opt/LIP/masks/'
# Number of channels of the model's final layer (one logit per class).
OUTPUT_CHANNELS = 20
EPOCHS = 20
BATCH_SIZE = 256
def display(display_list):
    """Plot the given images side by side in a single row.

    The panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask', so at most three entries can be shown.
    """
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    panel_count = len(display_list)
    plt.figure(figsize=(15, 15))
    for position, picture in enumerate(display_list):
        plt.subplot(1, panel_count, position + 1)
        plt.title(panel_titles[position])
        rendered = tf.keras.preprocessing.image.array_to_img(picture)
        plt.imshow(rendered)
        plt.axis('off')
    plt.show()
def load_input(image_file):
    """Read a JPEG file and return it as a normalized image tensor.

    Args:
        image_file: Path of the JPEG image to read.

    Returns:
        A tensor of shape [IMG_WIDTH, IMG_WIDTH, 3] with values scaled
        to the range [-1, 1].
    """
    raw_bytes = tf.io.read_file(image_file)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, [IMG_WIDTH, IMG_WIDTH])
    # Map pixel values from [0, 255] to [-1, 1].
    return (resized / 127.5) - 1
def load_mask(image_file):
    """Read a single-channel PNG mask and resize it to the network input size.

    Args:
        image_file: Path of the PNG mask to read.

    Returns:
        A float32 tensor of shape [IMG_WIDTH, IMG_WIDTH, 1]; its pixel values
        are used as the sparse class labels during training.
    """
    raw_bytes = tf.io.read_file(image_file)
    decoded = tf.image.decode_png(raw_bytes, channels=1)
    # Nearest-neighbour resizing keeps every pixel an original label value;
    # the default bilinear method would blend neighbouring class ids into
    # values that do not correspond to any class.
    resized = tf.image.resize(decoded, [IMG_WIDTH, IMG_WIDTH], method='nearest')
    # Nearest-neighbour preserves the decoded integer dtype, so cast back to
    # float32 to keep the dtype callers received before this change.
    return tf.cast(resized, tf.float32)
def load(image_file, mask_file):
    """Load one (image, mask) pair from the two given file paths."""
    return load_input(image_file), load_mask(mask_file)
# Training split: list the image and mask files in sorted order so that
# position k in one listing is paired with position k in the other
# (assumes image and mask file names sort consistently).
train_image_path = os.path.join(IM_PATH+'train/')
train_mask_path = os.path.join(MS_PATH+'train/')
train_images = sorted(os.listdir(train_image_path))
train_masks = sorted(os.listdir(train_mask_path))
train_ls_images = [IM_PATH+'train/'+file_name for file_name in train_images]
train_ls_masks = [MS_PATH+'train/'+file_name for file_name in train_masks]
train_images = tf.constant(train_ls_images)
train_labels = tf.constant(train_ls_masks)
# Dataset of (image path, mask path) pairs, decoded by `load`, then batched.
train_data = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)
).map(load, num_parallel_calls=4)
train_batched_data = train_data.batch(BATCH_SIZE)
# Validation split: list the image and mask files in sorted order so that
# position k in one listing is paired with position k in the other
# (assumes image and mask file names sort consistently).
val_image_path = os.path.join(IM_PATH+'val/')
val_mask_path = os.path.join(MS_PATH+'val/')
val_images = sorted(os.listdir(val_image_path))
val_masks = sorted(os.listdir(val_mask_path))
val_ls_images = [IM_PATH+'val/'+file_name for file_name in val_images]
val_ls_masks = [MS_PATH+'val/'+file_name for file_name in val_masks]
val_images = tf.constant(val_ls_images)
val_labels = tf.constant(val_ls_masks)
# Dataset of (image path, mask path) pairs, decoded by `load`, then batched.
val_data = tf.data.Dataset.from_tensor_slices(
    (val_images, val_labels)
).map(load, num_parallel_calls=4)
val_batched_data = val_data.batch(BATCH_SIZE)
# MobileNetV2 without its classification head serves as the encoder.
base_model = tf.keras.applications.MobileNetV2(
    include_top=False,
    input_shape=[128, 128, 3],
)
# Layers whose activations are tapped, from highest to lowest resolution.
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
layers = [
    base_model.get_layer(layer_name).output for layer_name in layer_names
]
# Feature-extraction model returning every tapped activation; it is frozen
# so that it is not updated during training.
down_stack = tf.keras.Model(outputs=layers, inputs=base_model.input)
down_stack.trainable = False
# Decoder: four pix2pix upsample blocks with kernel size 3, taking the
# feature map 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
up_stack = [pix2pix.upsample(filters, 3) for filters in (512, 256, 128, 64)]
def unet_model(output_channels):
    """Assemble the U-Net from `down_stack` (encoder) and `up_stack` (decoder).

    Args:
        output_channels: Number of filters of the final transposed
            convolution, i.e. the channel count of the model output.

    Returns:
        A `tf.keras.Model` taking a 128x128x3 input.
    """
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])

    # Encoder pass: the last activation is the bottleneck, the earlier ones
    # become skip connections.
    *skip_features, bottleneck = down_stack(inputs)

    # Decoder pass: upsample, then concatenate with the matching skip,
    # walking the skips from lowest to highest resolution.
    x = bottleneck
    for upsample_block, skip in zip(up_stack, reversed(skip_features)):
        x = upsample_block(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final transposed convolution: 64x64 -> 128x128
    output_layer = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2, padding='same')
    return tf.keras.Model(inputs=inputs, outputs=output_layer(x))
# Build the network and configure training. The final layer has no
# activation, so the loss is told to expect logits.
model = unet_model(output_channels=OUTPUT_CHANNELS)
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
    optimizer='adam',
)
# Render a diagram of the model including layer shapes.
tf.keras.utils.plot_model(model, show_shapes=True)
# Keep the first (unbatched) training example as a fixed sample that is
# reused for every visual check, and show it once.
for sample in train_data.take(1):
    sample_image, sample_mask = sample
display([sample_image, sample_mask])
def create_mask(pred_mask):
    """Convert batched per-class scores into a label mask for the first item.

    Takes the arg-max over the last axis, appends a trailing axis and
    returns the entry at index 0 of the leading axis.
    """
    return tf.argmax(pred_mask, axis=-1)[..., tf.newaxis][0]
def show_predictions(dataset=None, num=1):
    """Display model predictions next to the input image and the true mask.

    Args:
        dataset: Optional dataset yielding (image, mask) elements. Each
            element is passed to the model as a batch and only its first
            item is displayed. When None, the stored `sample_image` /
            `sample_mask` pair is shown instead.
        num: Number of elements taken from `dataset`; ignored when
            `dataset` is None.
    """
    # Explicit None check instead of relying on the truthiness of a Dataset.
    if dataset is None:
        # The sample is a single image, so add a leading batch axis first.
        prediction = model.predict(sample_image[tf.newaxis, ...])
        display([sample_image, sample_mask, create_mask(prediction)])
        return

    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display([image[0], mask[0], create_mask(pred_mask)])
# Show the prediction for the stored sample; this call runs before model.fit.
show_predictions()
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Clear the notebook output, show the prediction and print the epoch."""
        finished_epoch = epoch + 1  # `epoch` is zero-based
        clear_output(wait=True)
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(finished_epoch))
# Train on the batched training set, redrawing the sample prediction after
# each epoch. Validation (validation_data=val_batched_data) is left disabled.
model_history = model.fit(
    x=train_batched_data,
    callbacks=[DisplayCallback()],
    epochs=EPOCHS,
)
说明:
执行这个代码之前需要的准备工作:
(1)下载数据集:本次使用的是 LIP 数据集。为了方便读取,我调整了下载后数据的存放位置。
(2)下载tensorflow_examples的文件放置在指定的路径(我是放置在/opt/LIP/examples)
(3)一定要使用 TensorFlow 2.3 版本,更低的版本很容易出错,特别是低于 2.0 的版本。
最终保存了模型文件,大小在 60 MB 左右。
本次代码主要参考:https://tensorflow.google.cn/tutorials/images/segmentation
准确率:
效果不是很好,但是也还可以使用,毕竟只训练了 20 轮。
对于 MobileNetV2 的模型文件,如果代码拉取下载很慢,可以提前下载好放在这个路径下面:/root/.keras/models/
可以看一下:
因为自动下载的路径就是 /root/.keras/models/。