Image Segmentation with TensorFlow 2 on the Oxford-IIIT Pets Dataset (pretrained MobileNetV2 + pix2pix)

Image segmentation assigns a label to every pixel of an image; the task here is to train a neural network that outputs a per-pixel mask for each input image.
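Concretely, for a 128x128 RGB input the training target is a 128x128 integer mask with one class id per pixel, and the model emits one score per class per pixel. A minimal shape sketch (made-up tensors, matching the sizes used later in this post):

import tensorflow as tf

image = tf.zeros([1, 128, 128, 3])            # one RGB image
mask = tf.zeros([1, 128, 128, 1], tf.int32)   # one class id per pixel
logits = tf.zeros([1, 128, 128, 3])           # model output: 3 class scores per pixel
print(image.shape, mask.shape, logits.shape)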

1. Import the required libraries

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import matplotlib.pyplot as plt
from IPython.display import clear_output

for i in [tf, tfds]:  # print the library versions used in this post
    print(i.__name__,": ",i.__version__,sep="")

Output:

tensorflow: 2.2.0
tensorflow_datasets: 3.1.0

2. Download the dataset

The Oxford-IIIT Pets dataset consists of images of 37 pet breeds, with roughly 200 images per breed; the images vary widely in scale, pose, and lighting. Each example comprises the image, its class label, and a pixel-wise segmentation mask. Every pixel belongs to one of three classes:

  1. The pixel is part of the pet
  2. The pixel is on the pet's border
  3. Neither of the above

dataset, info = tfds.load("oxford_iiit_pet:3.*.*",with_info=True)  # segmentation masks are only available from version 3.0.0 onward

Output:

Downloading and preparing dataset oxford_iiit_pet/3.2.0 (download: 773.52 MiB, generated: 774.69 MiB, total: 1.51 GiB) to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0...
Shuffling and writing examples to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0.incomplete24GDUG\oxford_iiit_pet-train.tfrecord
Shuffling and writing examples to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0.incomplete24GDUG\oxford_iiit_pet-test.tfrecord
Dataset oxford_iiit_pet downloaded and prepared to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0. Subsequent calls will reuse this data.
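Before preprocessing, it can be worth confirming that the raw masks really use the labels {1, 2, 3} (an optional sanity check; the printed order of the unique values may vary):

# Flatten one raw segmentation mask and list its unique label values.
sample = next(iter(dataset["train"]))
mask_values = tf.unique(tf.cast(tf.reshape(sample["segmentation_mask"], [-1]), tf.int32))[0]
print(mask_values)  # expected to contain only 1, 2 and 3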

3. Data preprocessing

The downloaded data is preprocessed as follows:

  1. Random horizontal flipping (data augmentation)
  2. Normalizing the pixel values to [0, 1]
  3. Shifting the segmentation mask labels from {1, 2, 3} to {0, 1, 2}

def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32)/255.0  # scale pixel values to [0, 1]
    input_mask -= 1  # shift mask labels from {1, 2, 3} to {0, 1, 2}
    return input_image, input_mask

@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint["image"],(128,128))  # resize to 128x128
    input_mask = tf.image.resize(datapoint["segmentation_mask"],(128,128))
    
    if tf.random.uniform(()) > 0.5:  # random horizontal flip with 50% probability
        input_image = tf.image.flip_left_right(input_image)  # flip image and mask together
        input_mask = tf.image.flip_left_right(input_mask)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint["image"],(128,128))
    input_mask = tf.image.resize(datapoint["segmentation_mask"],(128,128))
    
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
train_length = info.splits["train"].num_examples  # keep the dataset's original train/test split
batch_size = 64
buffer_size = 1000
steps_per_epoch = train_length//batch_size

train = dataset["train"].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset["test"].map(load_image_test)

train_dataset = train.cache().shuffle(buffer_size).batch(batch_size).repeat()  # cache, shuffle, batch, repeat
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)  # overlap preprocessing and training
test_dataset = test.batch(batch_size)
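# Optional check: each batch should be (images, masks) at 128x128 resolution.
# (element_spec formatting varies slightly across TF versions.)
print(train_dataset.element_spec)
# e.g. (TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None),
#       TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))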
# Display helper: show an input image, its true mask and (optionally) a predicted mask side by side
def display(display_list):
    plt.figure(figsize=(10,10))
    title = ["Input Image","True Mask","Predicted Mask"]
    
    for i in range(len(display_list)):
        plt.subplot(1,len(display_list),i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis("off")
    plt.tight_layout()
    plt.show()

for image, mask in train.take(5):  # iterate over five samples; keep the last one
    sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

Output:

(Figure: a sample input image alongside its ground-truth mask.)

4. Build the model

The model is a modified U-Net, which consists of an encoder (downsampler) and a decoder (upsampler). To learn more robust features and reduce the number of trainable parameters, the encoder uses a pretrained MobileNetV2 model and taps several of its intermediate activations; the decoder uses the upsample blocks from pix2pix in TensorFlow Examples.

output_channels = 3  # three output channels: one per pixel class (see section 2)

base_model = tf.keras.applications.MobileNetV2(input_shape=[128,128,3],include_top=False)  # encoder backbone without the classification head

layer_names = ["block_1_expand_relu",   # 64x64
              "block_3_expand_relu",    # 32x32
              "block_6_expand_relu",    # 16x16
              "block_13_expand_relu",   # 8x8
              "block_16_project"]       # 4x4
layers = [base_model.get_layer(name).output for name in layer_names]

down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)  # feature-extraction model
down_stack.trainable = False  # freeze the pretrained encoder

Output:

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9412608/9406464 [==============================] - 131s 14us/step
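To see what the frozen encoder hands to the decoder, the skip-connection shapes can be printed (an optional check; the shapes below assume the 128x128 input used throughout this post):

# Feed a dummy image through the encoder and list each tap's resolution.
for t in down_stack(tf.zeros([1, 128, 128, 3])):
    print(t.shape)
# Expected: (1, 64, 64, 96), (1, 32, 32, 144), (1, 16, 16, 192),
#           (1, 8, 8, 576), (1, 4, 4, 320)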
# The decoder is simply a stack of upsample blocks
up_stack = [pix2pix.upsample(512,3),
           pix2pix.upsample(256,3),
           pix2pix.upsample(128,3),
           pix2pix.upsample(64,3),]
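pix2pix.upsample returns a small Sequential block that doubles the spatial resolution. If tensorflow_examples is not installed, a stand-in can be sketched as below; this mirrors the block's documented default behaviour (Conv2DTranspose -> BatchNorm -> ReLU) but is an assumption, not the library's exact source:

def upsample_sketch(filters, size):
    # Assumed equivalent of pix2pix.upsample: a transposed convolution that
    # doubles height and width, followed by batch norm and ReLU.
    init = tf.random_normal_initializer(0.0, 0.02)
    block = tf.keras.Sequential()
    block.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                              padding="same",
                                              kernel_initializer=init,
                                              use_bias=False))
    block.add(tf.keras.layers.BatchNormalization())
    block.add(tf.keras.layers.ReLU())
    return block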

def unet_model(output_channels):
    inputs = tf.keras.layers.Input(shape=[128,128,3])
    x = inputs
    
    skips = down_stack(x)  # encode: collect the skip-connection activations
    x = skips[-1]          # start decoding from the deepest feature map
    skips = reversed(skips[:-1])
    
    for up, skip in zip(up_stack, skips):  # upsample and merge with the matching skip
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x,skip])
        
    last = tf.keras.layers.Conv2DTranspose(output_channels,3,
                                           strides=2,
                                           padding="same")  # 64x64 -> 128x128, one logit per class
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)
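Before moving on to training, a quick smoke test (optional) confirms that the network maps a 128x128 input back to a mask of the same resolution with one channel per class:

# The logits should have the same spatial size as the input image.
smoke = unet_model(output_channels)(tf.zeros([1, 128, 128, 3]))
print(smoke.shape)  # expected: (1, 128, 128, 3)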

5. Train the model

model = unet_model(output_channels)
model.compile(optimizer="adam",
             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # integer mask labels, raw logits
             metrics=["accuracy"])

model.summary()

Output:

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_5 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
model_2 (Model)                 [(None, 64, 64, 96), 1841984     input_5[0][0]                    
__________________________________________________________________________________________________
sequential_4 (Sequential)       (None, 8, 8, 512)    1476608     model_2[2][4]                    
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 8, 8, 1088)   0           sequential_4[1][0]               
                                                                 model_2[2][3]                    
__________________________________________________________________________________________________
sequential_5 (Sequential)       (None, 16, 16, 256)  2507776     concatenate_2[0][0]              
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 16, 16, 448)  0           sequential_5[0][0]               
                                                                 model_2[2][2]                    
__________________________________________________________________________________________________
sequential_6 (Sequential)       (None, 32, 32, 128)  516608      concatenate_3[0][0]              
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 32, 32, 272)  0           sequential_6[0][0]               
                                                                 model_2[2][1]                    
__________________________________________________________________________________________________
sequential_7 (Sequential)       (None, 64, 64, 64)   156928      concatenate_4[0][0]              
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 64, 64, 160)  0           sequential_7[0][0]               
                                                                 model_2[2][0]                    
__________________________________________________________________________________________________
conv2d_transpose_10 (Conv2DTran (None, 128, 128, 3)  4323        concatenate_5[0][0]              
==================================================================================================
Total params: 6,504,227
Trainable params: 4,660,323
Non-trainable params: 1,843,904
__________________________________________________________________________________________________
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)  # most likely class per pixel
    pred_mask = pred_mask[...,tf.newaxis]      # restore the channel axis
    return pred_mask[0]                        # return the first image in the batch
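# create_mask keeps, for each pixel, the class with the highest logit.
# Toy example on a single made-up 1x1 "image" whose logits favour class 1:
toy_logits = tf.constant([[[[0.1, 2.0, -1.0]]]])  # shape (1, 1, 1, 3)
print(create_mask(toy_logits).numpy())            # -> [[[1]]]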

def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0],mask[0],create_mask(pred_mask)])
    else:
        display([sample_image,sample_mask,
                create_mask(model.predict(sample_image[tf.newaxis,...]))])

show_predictions()

Output:

(Figure: the sample image, its true mask, and the untrained model's predicted mask.)

class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        #clear_output(wait=True)  # uncomment to keep only the latest figure
        show_predictions()  # visualize how predictions improve as training progresses
        print("\nSample Prediction after epoch {}\n".format(epoch+1))

epochs = 20
val_subsplits = 5
validation_steps = info.splits["test"].num_examples//batch_size//val_subsplits  # validate on 1/5 of the test set per epoch

model_history = model.fit(train_dataset, epochs=epochs,
                          steps_per_epoch=steps_per_epoch,
                          validation_steps=validation_steps,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])

Output:

Epoch 1/20
57/57 [==============================] - ETA: 0s - loss: 0.2981 - accuracy: 0.8778

Sample Prediction after epoch 1

57/57 [==============================] - 11s 190ms/step - loss: 0.2981 - accuracy: 0.8778 - val_loss: 0.3067 - val_accuracy: 0.8803
Epoch 2/20
57/57 [==============================] - ETA: 0s - loss: 0.2775 - accuracy: 0.8850

Sample Prediction after epoch 2

57/57 [==============================] - 11s 191ms/step - loss: 0.2775 - accuracy: 0.8850 - val_loss: 0.2926 - val_accuracy: 0.8799
Epoch 3/20
57/57 [==============================] - ETA: 0s - loss: 0.2626 - accuracy: 0.8903

Sample Prediction after epoch 3

57/57 [==============================] - 11s 194ms/step - loss: 0.2626 - accuracy: 0.8903 - val_loss: 0.2793 - val_accuracy: 0.8863
Epoch 4/20
57/57 [==============================] - ETA: 0s - loss: 0.2512 - accuracy: 0.8941

Sample Prediction after epoch 4

57/57 [==============================] - 11s 193ms/step - loss: 0.2512 - accuracy: 0.8941 - val_loss: 0.2889 - val_accuracy: 0.8815
Epoch 5/20
57/57 [==============================] - ETA: 0s - loss: 0.2485 - accuracy: 0.8950

Sample Prediction after epoch 5

57/57 [==============================] - 11s 191ms/step - loss: 0.2485 - accuracy: 0.8950 - val_loss: 0.2874 - val_accuracy: 0.8852
Epoch 6/20
57/57 [==============================] - ETA: 0s - loss: 0.2366 - accuracy: 0.8996

Sample Prediction after epoch 6

57/57 [==============================] - 11s 191ms/step - loss: 0.2366 - accuracy: 0.8996 - val_loss: 0.2734 - val_accuracy: 0.8900
Epoch 7/20
57/57 [==============================] - ETA: 0s - loss: 0.2294 - accuracy: 0.9019

Sample Prediction after epoch 7

57/57 [==============================] - 11s 191ms/step - loss: 0.2294 - accuracy: 0.9019 - val_loss: 0.2719 - val_accuracy: 0.8898
Epoch 8/20
57/57 [==============================] - ETA: 0s - loss: 0.2208 - accuracy: 0.9051

Sample Prediction after epoch 8

57/57 [==============================] - 11s 194ms/step - loss: 0.2208 - accuracy: 0.9051 - val_loss: 0.2778 - val_accuracy: 0.8888
Epoch 9/20
57/57 [==============================] - ETA: 0s - loss: 0.2146 - accuracy: 0.9072

Sample Prediction after epoch 9

57/57 [==============================] - 11s 192ms/step - loss: 0.2146 - accuracy: 0.9072 - val_loss: 0.2747 - val_accuracy: 0.8878
Epoch 10/20
57/57 [==============================] - ETA: 0s - loss: 0.2037 - accuracy: 0.9114

Sample Prediction after epoch 10

57/57 [==============================] - 11s 192ms/step - loss: 0.2037 - accuracy: 0.9114 - val_loss: 0.2805 - val_accuracy: 0.8876
Epoch 11/20
57/57 [==============================] - ETA: 0s - loss: 0.1947 - accuracy: 0.9148

Sample Prediction after epoch 11

57/57 [==============================] - 11s 190ms/step - loss: 0.1947 - accuracy: 0.9148 - val_loss: 0.2765 - val_accuracy: 0.8902
Epoch 12/20
57/57 [==============================] - ETA: 0s - loss: 0.1894 - accuracy: 0.9167

Sample Prediction after epoch 12

57/57 [==============================] - 11s 191ms/step - loss: 0.1894 - accuracy: 0.9167 - val_loss: 0.2732 - val_accuracy: 0.8928
Epoch 13/20
57/57 [==============================] - ETA: 0s - loss: 0.1811 - accuracy: 0.9200

Sample Prediction after epoch 13

57/57 [==============================] - 11s 191ms/step - loss: 0.1811 - accuracy: 0.9200 - val_loss: 0.2983 - val_accuracy: 0.8852
Epoch 14/20
57/57 [==============================] - ETA: 0s - loss: 0.1741 - accuracy: 0.9226

Sample Prediction after epoch 14

57/57 [==============================] - 11s 190ms/step - loss: 0.1741 - accuracy: 0.9226 - val_loss: 0.2878 - val_accuracy: 0.8915
Epoch 15/20
57/57 [==============================] - ETA: 0s - loss: 0.1653 - accuracy: 0.9261

Sample Prediction after epoch 15

57/57 [==============================] - 11s 191ms/step - loss: 0.1653 - accuracy: 0.9261 - val_loss: 0.2972 - val_accuracy: 0.8923
Epoch 16/20
57/57 [==============================] - ETA: 0s - loss: 0.1565 - accuracy: 0.9298

Sample Prediction after epoch 16

57/57 [==============================] - 11s 191ms/step - loss: 0.1565 - accuracy: 0.9298 - val_loss: 0.3054 - val_accuracy: 0.8890
Epoch 17/20
57/57 [==============================] - ETA: 0s - loss: 0.1533 - accuracy: 0.9308

Sample Prediction after epoch 17

57/57 [==============================] - 11s 191ms/step - loss: 0.1533 - accuracy: 0.9308 - val_loss: 0.2971 - val_accuracy: 0.8921
Epoch 18/20
57/57 [==============================] - ETA: 0s - loss: 0.1443 - accuracy: 0.9347

Sample Prediction after epoch 18

57/57 [==============================] - 11s 191ms/step - loss: 0.1443 - accuracy: 0.9347 - val_loss: 0.3004 - val_accuracy: 0.8923
Epoch 19/20
57/57 [==============================] - ETA: 0s - loss: 0.1362 - accuracy: 0.9380

Sample Prediction after epoch 19

57/57 [==============================] - 11s 191ms/step - loss: 0.1362 - accuracy: 0.9380 - val_loss: 0.3153 - val_accuracy: 0.8921
Epoch 20/20
57/57 [==============================] - ETA: 0s - loss: 0.1350 - accuracy: 0.9386

Sample Prediction after epoch 20

57/57 [==============================] - 11s 197ms/step - loss: 0.1350 - accuracy: 0.9386 - val_loss: 0.3201 - val_accuracy: 0.8906

6. Visualize the results

accuracy = model_history.history["accuracy"]
val_accuracy = model_history.history["val_accuracy"]
loss = model_history.history["loss"]
val_loss = model_history.history["val_loss"]

epochs_range = range(epochs)  # x-axis values; avoid shadowing the epochs count
plt.figure(figsize=(10,5))

plt.subplot(1,2,1)
plt.plot(epochs_range, accuracy, "r", label="Training Accuracy")
plt.plot(epochs_range, val_accuracy, "b", label="Validation Accuracy")
plt.title("Training and Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.ylim([0.85,1])
plt.legend()

plt.subplot(1,2,2)
plt.plot(epochs_range, loss, "r", label="Training Loss")
plt.plot(epochs_range, val_loss, "b", label="Validation Loss")
plt.title("Training and Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss Value")
plt.ylim([0,0.5])
plt.legend()

plt.tight_layout()
plt.show()

Output:

(Figure: training/validation accuracy and loss curves over the 20 epochs.)

7. Predict with the trained model

show_predictions(test_dataset,10)

Output:

(Figure: predictions on ten test images: input, true mask, and predicted mask.)

As the results show, the predictions are fairly accurate; the model recovers the pet outlines reasonably well.