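Image segmentation assigns a label to every pixel of an image. The task here is to train a neural network that outputs a pixel-wise mask for an input image.
1. Import the required libraries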
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_examples.models.pix2pix.pix2pix as pix2pix
import matplotlib.pyplot as plt
from IPython.display import clear_output
for i in [tf, tfds]:
    print(i.__name__, ": ", i.__version__, sep="")
Output:
tensorflow: 2.2.0
tensorflow_datasets: 3.1.0
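2. Download the dataset
The Oxford-IIIT Pet dataset contains images of 37 pet categories, roughly 200 images per category, with large variations in scale, pose, and lighting. Each example consists of an image, its class label, and a pixel-wise segmentation mask. Every pixel in the mask belongs to one of three classes:
- The pixel is part of the pet
- The pixel is on the outline of the pet
- None of the above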
dataset, info = tfds.load("oxford_iiit_pet:3.*.*", with_info=True)  # segmentation masks are only available from version 3.0.0 onward
Output:
Downloading and preparing dataset oxford_iiit_pet/3.2.0 (download: 773.52 MiB, generated: 774.69 MiB, total: 1.51 GiB) to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0...
Shuffling and writing examples to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0.incomplete24GDUG\oxford_iiit_pet-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=3680.0), HTML(value='')))
Shuffling and writing examples to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0.incomplete24GDUG\oxford_iiit_pet-test.tfrecord
HBox(children=(FloatProgress(value=0.0, max=3669.0), HTML(value='')))
Dataset oxford_iiit_pet downloaded and prepared to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0. Subsequent calls will reuse this data.
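3. Preprocess the data
The downloaded data is preprocessed as follows:
- Randomly flip images horizontally (training set only)
- Normalize pixel values to [0, 1]
- Shift the segmentation mask labels from {1, 2, 3} to {0, 1, 2}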
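These steps can be tried on a toy array before touching the real pipeline. The following is a minimal NumPy sketch, not the TensorFlow code used in this post; the helper name `preprocess_toy` and the 2x3 arrays are made up for illustration.

```python
import numpy as np

def preprocess_toy(image, mask, flip):
    """Toy version of the preprocessing steps (hypothetical helper)."""
    if flip:
        # Flip image and mask together so every pixel keeps its label.
        image = image[:, ::-1]
        mask = mask[:, ::-1]
    image = image.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    mask = mask - 1                           # labels {1, 2, 3} -> {0, 1, 2}
    return image, mask

image = np.array([[0, 128, 255],
                  [255, 128, 0]], dtype=np.uint8)
mask = np.array([[1, 2, 3],
                 [3, 2, 1]], dtype=np.uint8)

out_image, out_mask = preprocess_toy(image, mask, flip=True)
print(out_mask)         # flipped mask with labels shifted down by one
print(out_image.max())  # largest normalized pixel value
```

The point to notice is that the flip is applied to the image and the mask with the same decision; flipping only one of them would misalign the labels.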
def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0  # scale pixel values to [0, 1]
    input_mask -= 1  # shift labels from {1, 2, 3} to {0, 1, 2}
    return input_image, input_mask
@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint["image"], (128, 128))  # resize the image
    # Nearest-neighbor keeps the mask values as integer labels; the default bilinear
    # interpolation would blend labels along class boundaries.
    input_mask = tf.image.resize(datapoint["segmentation_mask"], (128, 128), method="nearest")
    if tf.random.uniform(()) > 0.5:  # flip horizontally with probability 0.5
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint["image"], (128, 128))
    input_mask = tf.image.resize(datapoint["segmentation_mask"], (128, 128), method="nearest")
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
train_length = info.splits["train"].num_examples  # size of the dataset's original train split
batch_size = 64
buffer_size = 1000
steps_per_epoch = train_length//batch_size
train = dataset["train"].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset["test"].map(load_image_test)
train_dataset = train.cache().shuffle(buffer_size).batch(batch_size).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(batch_size)
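As a quick check of the arithmetic: the train split holds 3680 examples (the `max=3680.0` value in the download log above), so a batch size of 64 gives 57 full batches plus one partial batch. The sketch below is plain Python and only reproduces the integer arithmetic; `train_length` is hard-coded here rather than read from `info`, which is an assumption that the split size has not changed.

```python
train_length = 3680  # taken from the download log (assumed unchanged)
batch_size = 64

steps_per_epoch = train_length // batch_size  # full batches only
leftover = train_length % batch_size          # examples in the final partial batch

print(steps_per_epoch)  # 57
print(leftover)         # 32
```

Because the pipeline calls `batch()` before `repeat()`, the 32-example partial batch is still produced at the end of each pass over the data; `steps_per_epoch` only determines how many batches Keras counts as one epoch.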
# Inspect a few samples
def display(display_list):
    plt.figure(figsize=(10, 10))
    title = ["Input Image", "True Mask", "Predicted Mask"]
    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis("off")
    plt.tight_layout()
    plt.show()
for image, mask in train.take(5):
    sample_image, sample_mask = image, mask  # the last sample is reused for previews later
    display([sample_image, sample_mask])
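Output:
4. Build the model
The model is a modified U-Net. A U-Net consists of an encoder (downsampler) and a decoder (upsampler). To learn more robust features and reduce the number of trainable parameters, a pretrained MobileNetV2 serves as the encoder, and the outputs of several of its intermediate layers are used. The decoder is built from the upsample block of the pix2pix implementation in TensorFlow Examples.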
output_channels = 3  # three output channels because each pixel has three possible labels (see section 2)
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)
# Activations of these layers are used as the encoder outputs
layer_names = ["block_1_expand_relu",   # 64x64
               "block_3_expand_relu",   # 32x32
               "block_6_expand_relu",   # 16x16
               "block_13_expand_relu",  # 8x8
               "block_16_project"]      # 4x4
layers = [base_model.get_layer(name).output for name in layer_names]
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)  # feature-extraction encoder
down_stack.trainable = False  # freeze the pretrained encoder
Output:
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9412608/9406464 [==============================] - 131s 14us/step
# The decoder is a series of upsample blocks
up_stack = [pix2pix.upsample(512, 3),  # 4x4 -> 8x8
            pix2pix.upsample(256, 3),  # 8x8 -> 16x16
            pix2pix.upsample(128, 3),  # 16x16 -> 32x32
            pix2pix.upsample(64, 3)]   # 32x32 -> 64x64
def unet_model(output_channels):
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])
    x = inputs
    skips = down_stack(x)  # downsample through the encoder
    x = skips[-1]
    skips = reversed(skips[:-1])
    for up, skip in zip(up_stack, skips):  # upsample and attach the skip connections
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])
    last = tf.keras.layers.Conv2DTranspose(output_channels, 3,
                                           strides=2,
                                           padding="same")  # 64x64 -> 128x128
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)
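To see why the skip connections line up, it helps to trace the spatial size through the network. The sketch below is plain Python bookkeeping, not model code. It assumes that each of the five tapped encoder layers halves the resolution of the previous one (starting at 64x64 for a 128x128 input) and that every upsample block and the final transposed convolution doubles it; these sizes agree with the shapes in the model summary in the next section.

```python
input_size = 128

# Assumed encoder output resolutions: each tapped layer halves the previous one.
encoder_sizes = [input_size // 2 ** (i + 1) for i in range(5)]

x = encoder_sizes[-1]                        # bottleneck (last encoder output)
skips = list(reversed(encoder_sizes[:-1]))   # skip tensors, deepest first

decoder_sizes = []
for skip in skips:
    x *= 2               # one upsample block with stride 2
    assert x == skip     # concatenation needs matching height and width
    decoder_sizes.append(x)

output_size = x * 2      # final Conv2DTranspose with strides=2
print(encoder_sizes, decoder_sizes, output_size)
```

The final stride-2 layer is what brings the 64x64 decoder output back to the 128x128 input resolution, so the network emits one set of class scores per input pixel.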
5. Train the model
model = unet_model(output_channels)
model.compile(optimizer="adam",
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=["accuracy"])
model.summary()
Output:
Model: "model_4"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_5 (InputLayer) [(None, 128, 128, 3) 0
__________________________________________________________________________________________________
model_2 (Model) [(None, 64, 64, 96), 1841984 input_5[0][0]
__________________________________________________________________________________________________
sequential_4 (Sequential) (None, 8, 8, 512) 1476608 model_2[2][4]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 8, 8, 1088) 0 sequential_4[1][0]
model_2[2][3]
__________________________________________________________________________________________________
sequential_5 (Sequential) (None, 16, 16, 256) 2507776 concatenate_2[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 16, 16, 448) 0 sequential_5[0][0]
model_2[2][2]
__________________________________________________________________________________________________
sequential_6 (Sequential) (None, 32, 32, 128) 516608 concatenate_3[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 32, 32, 272) 0 sequential_6[0][0]
model_2[2][1]
__________________________________________________________________________________________________
sequential_7 (Sequential) (None, 64, 64, 64) 156928 concatenate_4[0][0]
__________________________________________________________________________________________________
concatenate_5 (Concatenate) (None, 64, 64, 160) 0 sequential_7[0][0]
model_2[2][0]
__________________________________________________________________________________________________
conv2d_transpose_10 (Conv2DTran (None, 128, 128, 3) 4323 concatenate_5[0][0]
==================================================================================================
Total params: 6,504,227
Trainable params: 4,660,323
Non-trainable params: 1,843,904
__________________________________________________________________________________________________
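The numbers in this summary can be reproduced by hand. The sketch below is plain Python and rests on assumptions about library internals that this post does not state: each pix2pix upsample block is taken to be a 3x3 transposed convolution without bias followed by batch normalization (4 parameters per channel, 2 of them non-trainable), and the encoder channel widths are inferred from the Concatenate shapes and parameter counts above.

```python
# Encoder outputs, deepest first: bottleneck, then the four skip tensors.
bottleneck_channels = 320             # inferred width of the last encoder output
skip_channels = [576, 192, 144, 96]   # inferred from the Concatenate shapes
up_filters = [512, 256, 128, 64]
kernel = 3

in_ch = bottleneck_channels
block_params, concat_channels = [], []
for filters, skip in zip(up_filters, skip_channels):
    conv = kernel * kernel * in_ch * filters  # transposed conv, no bias (assumed)
    batch_norm = 4 * filters                  # gamma, beta, moving mean, moving variance
    block_params.append(conv + batch_norm)
    in_ch = filters + skip                    # channels after concatenation
    concat_channels.append(in_ch)

last_layer = kernel * kernel * in_ch * 3 + 3          # final Conv2DTranspose, with bias
frozen_encoder = 1841984                              # from the summary (model_2)
non_trainable = frozen_encoder + 2 * sum(up_filters)  # plus batch-norm moving statistics
total = frozen_encoder + sum(block_params) + last_layer

print(concat_channels)  # [1088, 448, 272, 160]
print(block_params)     # [1476608, 2507776, 516608, 156928]
print(last_layer, total, non_trainable)
```

Under these assumptions the totals match the summary: 6,504,227 parameters overall, of which 1,843,904 are non-trainable.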
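def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)  # most likely class for each pixel
    pred_mask = pred_mask[..., tf.newaxis]     # add a channel axis for plotting
    return pred_mask[0]                        # mask of the first image in the batch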
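What `create_mask` does is easiest to see on a tiny array. Below is a minimal NumPy sketch, assuming a batch of one 2x2 prediction with three class scores per pixel; the values are made up, and `create_mask_np` is a hypothetical stand-in for the TensorFlow function above.

```python
import numpy as np

def create_mask_np(pred):
    """NumPy stand-in for create_mask: (batch, H, W, classes) -> (H, W, 1)."""
    mask = np.argmax(pred, axis=-1)  # most likely class per pixel
    mask = mask[..., np.newaxis]     # add a channel axis for plotting
    return mask[0]                   # keep only the first image in the batch

pred = np.array([[[[2.0, 0.1, -1.0], [0.0, 3.0, 0.5]],
                  [[-0.5, 0.2, 1.5], [4.0, 1.0, 1.0]]]])  # shape (1, 2, 2, 3)

mask = create_mask_np(pred)
print(mask.shape)    # (2, 2, 1)
print(mask[..., 0])  # class index per pixel
```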
def show_predictions(dataset=None, num=1):
    if dataset is not None:
        # Show the first image of each of the first `num` batches
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        # Fall back to the sample picked during data inspection
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()
Output:
class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # clear_output(wait=True)
        show_predictions()
        print("\nSample Prediction after epoch {}\n".format(epoch + 1))
epochs = 20
val_subsplits = 5
validation_steps = info.splits["test"].num_examples//batch_size//val_subsplits
model_history = model.fit(train_dataset, epochs=epochs,
                          steps_per_epoch=steps_per_epoch,
                          validation_steps=validation_steps,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])
Output:
Epoch 1/20
57/57 [==============================] - ETA: 0s - loss: 0.2981 - accuracy: 0.8778
Sample Prediction after epoch 1
57/57 [==============================] - 11s 190ms/step - loss: 0.2981 - accuracy: 0.8778 - val_loss: 0.3067 - val_accuracy: 0.8803
Epoch 2/20
57/57 [==============================] - ETA: 0s - loss: 0.2775 - accuracy: 0.8850
Sample Prediction after epoch 2
57/57 [==============================] - 11s 191ms/step - loss: 0.2775 - accuracy: 0.8850 - val_loss: 0.2926 - val_accuracy: 0.8799
Epoch 3/20
57/57 [==============================] - ETA: 0s - loss: 0.2626 - accuracy: 0.8903
Sample Prediction after epoch 3
57/57 [==============================] - 11s 194ms/step - loss: 0.2626 - accuracy: 0.8903 - val_loss: 0.2793 - val_accuracy: 0.8863
Epoch 4/20
57/57 [==============================] - ETA: 0s - loss: 0.2512 - accuracy: 0.8941
Sample Prediction after epoch 4
57/57 [==============================] - 11s 193ms/step - loss: 0.2512 - accuracy: 0.8941 - val_loss: 0.2889 - val_accuracy: 0.8815
Epoch 5/20
57/57 [==============================] - ETA: 0s - loss: 0.2485 - accuracy: 0.8950
Sample Prediction after epoch 5
57/57 [==============================] - 11s 191ms/step - loss: 0.2485 - accuracy: 0.8950 - val_loss: 0.2874 - val_accuracy: 0.8852
Epoch 6/20
57/57 [==============================] - ETA: 0s - loss: 0.2366 - accuracy: 0.8996
Sample Prediction after epoch 6
57/57 [==============================] - 11s 191ms/step - loss: 0.2366 - accuracy: 0.8996 - val_loss: 0.2734 - val_accuracy: 0.8900
Epoch 7/20
57/57 [==============================] - ETA: 0s - loss: 0.2294 - accuracy: 0.9019
Sample Prediction after epoch 7
57/57 [==============================] - 11s 191ms/step - loss: 0.2294 - accuracy: 0.9019 - val_loss: 0.2719 - val_accuracy: 0.8898
Epoch 8/20
57/57 [==============================] - ETA: 0s - loss: 0.2208 - accuracy: 0.9051
Sample Prediction after epoch 8
57/57 [==============================] - 11s 194ms/step - loss: 0.2208 - accuracy: 0.9051 - val_loss: 0.2778 - val_accuracy: 0.8888
Epoch 9/20
57/57 [==============================] - ETA: 0s - loss: 0.2146 - accuracy: 0.9072
Sample Prediction after epoch 9
57/57 [==============================] - 11s 192ms/step - loss: 0.2146 - accuracy: 0.9072 - val_loss: 0.2747 - val_accuracy: 0.8878
Epoch 10/20
57/57 [==============================] - ETA: 0s - loss: 0.2037 - accuracy: 0.9114
Sample Prediction after epoch 10
57/57 [==============================] - 11s 192ms/step - loss: 0.2037 - accuracy: 0.9114 - val_loss: 0.2805 - val_accuracy: 0.8876
Epoch 11/20
57/57 [==============================] - ETA: 0s - loss: 0.1947 - accuracy: 0.9148
Sample Prediction after epoch 11
57/57 [==============================] - 11s 190ms/step - loss: 0.1947 - accuracy: 0.9148 - val_loss: 0.2765 - val_accuracy: 0.8902
Epoch 12/20
57/57 [==============================] - ETA: 0s - loss: 0.1894 - accuracy: 0.9167
Sample Prediction after epoch 12
57/57 [==============================] - 11s 191ms/step - loss: 0.1894 - accuracy: 0.9167 - val_loss: 0.2732 - val_accuracy: 0.8928
Epoch 13/20
57/57 [==============================] - ETA: 0s - loss: 0.1811 - accuracy: 0.9200
Sample Prediction after epoch 13
57/57 [==============================] - 11s 191ms/step - loss: 0.1811 - accuracy: 0.9200 - val_loss: 0.2983 - val_accuracy: 0.8852
Epoch 14/20
57/57 [==============================] - ETA: 0s - loss: 0.1741 - accuracy: 0.9226
Sample Prediction after epoch 14
57/57 [==============================] - 11s 190ms/step - loss: 0.1741 - accuracy: 0.9226 - val_loss: 0.2878 - val_accuracy: 0.8915
Epoch 15/20
57/57 [==============================] - ETA: 0s - loss: 0.1653 - accuracy: 0.9261
Sample Prediction after epoch 15
57/57 [==============================] - 11s 191ms/step - loss: 0.1653 - accuracy: 0.9261 - val_loss: 0.2972 - val_accuracy: 0.8923
Epoch 16/20
57/57 [==============================] - ETA: 0s - loss: 0.1565 - accuracy: 0.9298
Sample Prediction after epoch 16
57/57 [==============================] - 11s 191ms/step - loss: 0.1565 - accuracy: 0.9298 - val_loss: 0.3054 - val_accuracy: 0.8890
Epoch 17/20
57/57 [==============================] - ETA: 0s - loss: 0.1533 - accuracy: 0.9308
Sample Prediction after epoch 17
57/57 [==============================] - 11s 191ms/step - loss: 0.1533 - accuracy: 0.9308 - val_loss: 0.2971 - val_accuracy: 0.8921
Epoch 18/20
57/57 [==============================] - ETA: 0s - loss: 0.1443 - accuracy: 0.9347
Sample Prediction after epoch 18
57/57 [==============================] - 11s 191ms/step - loss: 0.1443 - accuracy: 0.9347 - val_loss: 0.3004 - val_accuracy: 0.8923
Epoch 19/20
57/57 [==============================] - ETA: 0s - loss: 0.1362 - accuracy: 0.9380
Sample Prediction after epoch 19
57/57 [==============================] - 11s 191ms/step - loss: 0.1362 - accuracy: 0.9380 - val_loss: 0.3153 - val_accuracy: 0.8921
Epoch 20/20
57/57 [==============================] - ETA: 0s - loss: 0.1350 - accuracy: 0.9386
Sample Prediction after epoch 20
57/57 [==============================] - 11s 197ms/step - loss: 0.1350 - accuracy: 0.9386 - val_loss: 0.3201 - val_accuracy: 0.8906
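The loss column in this log is a sparse categorical cross-entropy computed from raw logits and averaged over pixels and batches. Below is a minimal NumPy sketch of that computation, assuming made-up logits for a 1x2 image; `sparse_ce_from_logits` is a hypothetical helper written for illustration, not the Keras implementation.

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Mean cross-entropy for integer labels and unnormalized class scores."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, labels[..., np.newaxis], axis=-1)
    return float(-picked.mean())

labels = np.array([[0, 2]])            # (H=1, W=2): integer class per pixel
logits = np.array([[[5.0, 0.0, 0.0],   # confident and correct
                    [0.0, 0.0, 0.0]]]) # uniform over the three classes

loss = sparse_ce_from_logits(labels, logits)
print(loss)
```

A pixel with uniform scores contributes ln 3 (about 1.10), which is the loss a three-class model would have if it knew nothing; the confident, correct pixel contributes almost nothing.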
6. Visualize the results
accuracy = model_history.history["accuracy"]
val_accuracy = model_history.history["val_accuracy"]
loss = model_history.history["loss"]
val_loss = model_history.history["val_loss"]
epochs = range(20)
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.plot(epochs, accuracy, "r", label="Train Accuracy")
plt.plot(epochs, val_accuracy, "b", label="Validation Accuracy")
plt.title("Training and Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.ylim([0.85,1])
plt.legend()
plt.subplot(1,2,2)
plt.plot(epochs, loss, "r", label="Training Loss")
plt.plot(epochs, val_loss, "b", label="Validation Loss")
plt.title("Training and Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss Value")
plt.ylim([0,0.5])
plt.legend()
plt.tight_layout()
plt.show()
Output:
7. Predict with the trained model
show_predictions(test_dataset,10)
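Output:
The predictions are fairly accurate: validation accuracy reaches about 0.89 in the run above, and the predicted masks largely recover the outline of each pet.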
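The accuracy reported during training is, in effect, pixel accuracy: the share of pixels whose predicted class equals the label. Here is a small NumPy sketch on made-up 2x3 masks; `pixel_accuracy` is a hypothetical helper, not part of the code above.

```python
import numpy as np

def pixel_accuracy(true_mask, pred_mask):
    """Fraction of pixels where the predicted class matches the label."""
    return float(np.mean(true_mask == pred_mask))

true_mask = np.array([[0, 0, 1],
                      [2, 2, 1]])
pred_mask = np.array([[0, 0, 1],
                      [2, 1, 1]])  # one of six pixels is wrong

acc = pixel_accuracy(true_mask, pred_mask)
print(acc)  # 5 of 6 pixels match
```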