UNET图像语义分割模型简介
代码
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
# 显存自适应分配
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu,True)
gpu_ok = tf.test.is_gpu_available()
print("tf version:", tf.__version__)
print("use GPU", gpu_ok) # 判断是否使用gpu进行训练
获取训练数据及目标值
# 获取train文件下所有文件中所有png的图片
img = glob.glob("G:/BaiduNetdiskDownload/cityscapes/leftImg8bit/train/*/*.png")
train_count = len(img)
img[:5],train_count
# 获取gtFine/train文件下所有文件中所有_gtFine_labelIds.png的图片
label = glob.glob("G:/BaiduNetdiskDownload/cityscapes/gtFine/train/*/*_gtFine_labelIds.png")
index = np.random.permutation(len(img)) # 创建一个随即种子,保障image和label 随机后还是一一对应的
img = np.array(img)[index] # 对训练集图片进行乱序
label = np.array(label)[index]
获取测试数据
# 获取val文件下所有文件中所有png的图片
img_val = glob.glob("G:/BaiduNetdiskDownload/cityscapes/leftImg8bit/val/*/*.png")
# 获取gtFine/val文件下所有文件中所有_gtFine_labelIds.png的图片
label_val = glob.glob("G:/BaiduNetdiskDownload/cityscapes/gtFine/val/*/*_gtFine_labelIds.png")
test_count = len(img_val)
img_val[:5],test_count,label_val[:5],len(label_val)
创建数据集
dataset_train = tf.data.Dataset.from_tensor_slices((img,label))
dataset_val = tf.data.Dataset.from_tensor_slices((img_val,label_val))
# 创建png的解码函数
def read_png(path):
img = tf.io.read_file(path)
img = tf.image.decode_png(img,channels=3)
return img
# 创建png的解码函数
def read_png_label(path):
img = tf.io.read_file(path)
img = tf.image.decode_png(img,channels=1)
return img
# 数据增强
def crop_img(img