TensorFlow2利用Oxford-IIIT Pets dataset数据集(MobileNetV2预训练模型和pix2pix)完成图像分隔任务

图像分隔就是给图像中的每个像素分配一个标签,图像分隔的任务是训练一个神经网络来输出该图像对每一个像素的掩码。

1. 导入所需的库

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_examples.models.pix2pix.pix2pix as pix2pix
import matplotlib.pyplot as plt
from IPython.display import clear_output

for i in [tf,tfds]:
    print(i.__name__,": ",i.__version__,sep="")

输出:

tensorflow: 2.2.0
tensorflow_datasets: 3.1.0

2. 下载数据集

Oxford-IIIT Pets dataset数据集由37类宠物图像组成,每类大约200张图像每张图像在比例、姿势、亮度等方面有很大差别。该数据集由图像、图像所对应的标签、以及对像素逐一标记的掩码组成。每个像素属于以下三类别之一:

  1. 像素是宠物的一部分
  2. 像素是宠物的轮廓
  3. 以上都不是
dataset, info = tfds.load("oxford_iiit_pet:3.*.*",with_info=True) # 3.0.0版本以上才有图像分隔的标签

输出:

Downloading and preparing dataset oxford_iiit_pet/3.2.0 (download: 773.52 MiB, generated: 774.69 MiB, total: 1.51 GiB) to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0.incomplete24GDUG\oxford_iiit_pet-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=3680.0), HTML(value='')))
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0.incomplete24GDUG\oxford_iiit_pet-test.tfrecord
HBox(children=(FloatProgress(value=0.0, max=3669.0), HTML(value='')))
Dataset oxford_iiit_pet downloaded and prepared to C:\Users\my-pc\tensorflow_datasets\oxford_iiit_pet\3.2.0. Subsequent calls will reuse this data.

3. 数据预处理

对下载的数据做以下预处理:

  1. 图像翻转
  2. 图像像素值归一化到[0,1]
  3. 分隔掩码由{1, 2, 3}变为{0, 1, 2}
def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32)/255.0 # 像素值归一化
    input_mask -= 1
    return input_image, input_mask

@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint["image"],(128,128)) # 图像resize
    input_mask = tf.image.resize(datapoint["segmentation_mask"],(128,128))
    
    if tf.random.uniform(()) > 0.5:  # 随机按50%进行水平翻转
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint["image"],(128,128))
    input_mask = tf.image.resize(datapoint["segmentation_mask"],(128,128))
    
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
train_length = info.splits["train"].num_examples # 按数据集原始的划分数目和比例
batch_size = 64
buffer_size = 1000
steps_per_epoch = train_length//batch_size

train = dataset["train"].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset["test"].map(load_image_test)

train_dataset = train.cache().shuffle(buffer_size).batch(batch_size).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(batch_size)
# 查看图像格式
def display(display_list):
    plt.figure(figsize=(10,10))
    title = ["Input Image","True Mask","Predicted Mask"]
    
    for i in range(len(display_list)):
        plt.subplot(1,len(display_list),i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis("off")
    plt.tight_layout()
    plt.show()

for image, mask in train.take(5):
    sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

输出:

4. 构建模型

这里使用的是一个修改版的U-Net。U-Net由一个编码器(降采样器)和解码器(上采样器)组成。为了学习到更鲁棒的特征和减少训练参数的数量,这里使用了预训练的MobileNet V2模型,使用其中间的输出结果。而解码器使用TensorFlow Examples中的pix2pix。

output_channels = 3 # 输出通道是3的原因是每个像
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值