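In semantic segmentation, an image and its label must be cropped in exactly the same way. To achieve this, concatenate the image and the label with concat, apply a random crop to the combined tensor, and then split it back into image and label.
For example, if image has shape (256,256,3) and label has shape (256,256,1), then after concat(axis=-1) concat_image has shape (256,256,4). Concatenating along the last axis requires the other dimensions to match and adds up the channel counts: 3 + 1 = 4.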
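The concat, crop and split sequence can be wrapped in a small helper. This is a minimal sketch, not the author's code: the name random_crop_pair and its default crop size are hypothetical, and it assumes image and label share height, width and dtype (both are uint8 after decode_png), because tf.concat cannot mix dtypes.

```python
import tensorflow as tf


def random_crop_pair(image, label, crop_h=256, crop_w=256):
    """Randomly crop image and label at the same location (hypothetical helper).

    Assumes image is (H, W, C), label is (H, W, 1), both with the same dtype,
    and H >= crop_h, W >= crop_w.
    """
    channels = image.shape[-1]  # static channel count, e.g. 3 for RGB
    # Stack along the channel axis so a single random offset applies to both
    stacked = tf.concat([image, label], axis=-1)  # (H, W, C + 1)
    cropped = tf.image.random_crop(stacked, (crop_h, crop_w, channels + 1))
    # Channels 0..C-1 are the image; channel C (the last one) is the label
    return cropped[:, :, :channels], cropped[:, :, channels:]


image = tf.zeros((280, 280, 3), dtype=tf.uint8)
label = tf.zeros((280, 280, 1), dtype=tf.uint8)
img_crop, lbl_crop = random_crop_pair(image, label)
print(img_crop.shape, lbl_crop.shape)  # (256, 256, 3) (256, 256, 1)
```

Reading the channel count from image.shape[-1] instead of hard-coding 3 and 4 keeps the split point correct if the image ever has a different number of channels.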
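After the concat, we crop with tf.image.random_crop and then slice the crop back into image and label. Take care with the indices of the last axis when slicing: as everywhere in Python, they start at 0, so the four channels are numbered 0-3, and the end index of a slice is excluded. crop_image[:,:,:3] therefore takes channels 0, 1 and 2 (the image), and crop_image[:,:,3:] takes channel 3 (the label). The same rule explains the output at the end: crop_image[:,:,:0] is empty and crop_image[:,:,:4] contains all four channels. Getting these indices wrong scrambles the data dimensions in every later step.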
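The slicing rule can be checked on a toy tensor. This is a hypothetical example, not taken from the original data: each channel c simply holds the value c. It also shows why 3: is preferable to 3 for the label (an integer index drops the channel axis), and tf.split as an alternative way to separate the channels without writing slice indices by hand.

```python
import tensorflow as tf

# Toy (2, 2, 4) tensor in which channel c holds the value c
x = tf.broadcast_to(tf.range(4), (2, 2, 4))

print(x[:, :, :3][0, 0].numpy())  # [0 1 2]   channels 0, 1, 2; end index 3 is excluded
print(x[:, :, 3:][0, 0].numpy())  # [3]       only channel 3, channel axis kept
print(x[:, :, 3].shape)           # (2, 2)    integer index drops the channel axis
print(x[:, :, :0].shape)          # (2, 2, 0) end index 0 excluded, so the slice is empty

# Alternative: tf.split with explicit sizes avoids hand-written index arithmetic
img_part, lbl_part = tf.split(x, [3, 1], axis=-1)
print(img_part.shape, lbl_part.shape)  # (2, 2, 3) (2, 2, 1)
```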
For example:
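import glob

import tensorflow as tf

# sorted() makes the i-th image and the i-th label refer to the same scene;
# glob returns files in arbitrary order (assumes image and label file names
# share the same city_sequence_frame prefix, as in Cityscapes)
train_images=sorted(glob.glob(r'./dataset/images/train/*/*.png'))
val_images=sorted(glob.glob(r'./dataset/images/val/*/*.png'))
test_images=sorted(glob.glob(r'./dataset/images/test/*/*.png'))
train_labels=sorted(glob.glob(r'./dataset/gtFine/train/*/*_gtFine_labelIds.png'))
val_labels=sorted(glob.glob(r'./dataset/gtFine/val/*/*_gtFine_labelIds.png'))
test_labels=sorted(glob.glob(r'./dataset/gtFine/test/*/*_gtFine_labelIds.png'))
print(train_images[0:5],train_labels[0:5])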
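def load_png(path):
    # Read a PNG file and decode it as a 3-channel (RGB) uint8 image
    image=tf.io.read_file(path)
    image=tf.image.decode_png(image,channels=3)
    return image

def load_label(path):
    # Read a label PNG and decode it as a single-channel uint8 map of class IDs
    label=tf.io.read_file(path)
    label=tf.image.decode_png(label,channels=1)
    return label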
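# Forward slashes work on both Windows and Linux
image_path='./dataset/images/train/aachen/aachen_000000_000019_leftImg8bit.png'
label_path='./dataset/gtFine/train/aachen/aachen_000000_000019_gtFine_labelIds.png'
image=load_png(image_path)    # (H, W, 3), uint8
label=load_label(label_path)  # (H, W, 1), uint8
# Both tensors are uint8 with the same H and W, so they can be stacked on the channel axis
concat_image=tf.concat([image,label],axis=-1)  # (H, W, 4)
# Nearest-neighbor resizing keeps label values valid class IDs (no interpolation) and keeps the uint8 dtype
concat_image=tf.image.resize(concat_image,(280,280),method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
# One random 256x256 window is applied to all 4 channels at once
crop_image=tf.image.random_crop(concat_image,(256,256,4))
# Channels 0-2 are the image, channel 3 is the label (the slice end is exclusive)
image=crop_image[:,:,:3]
label=crop_image[:,:,3:]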
print('crop_image.shape:',crop_image.shape,type(crop_image))
print('image.shape:',image.shape,'label.shape:',label.shape)
print('crop_image[:,:,:0].shape:',crop_image[:,:,:0].shape)
print('crop_image[:,:,:4].shape:',crop_image[:,:,:4].shape)
Output:
crop_image.shape: (256, 256, 4) <class 'tensorflow.python.framework.ops.EagerTensor'>
image.shape: (256, 256, 3) label.shape: (256, 256, 1)
crop_image[:,:,:0].shape: (256, 256, 0)
crop_image[:,:,:4].shape: (256, 256, 4)
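To confirm that the joint crop really keeps image and label aligned, one can build synthetic data in which the label is a copy of one image channel; after cropping, the two must still match. This is a sketch with made-up 10x10 data and a 6x6 crop, not part of the original pipeline.

```python
import tensorflow as tf

H, W = 10, 10
# Synthetic image: every element has a unique value, so any offset between
# the image crop and the label crop would produce a mismatch
image = tf.reshape(tf.range(H * W * 3, dtype=tf.int32), (H, W, 3))
# Synthetic label: a copy of image channel 0, shape (H, W, 1)
label = image[:, :, :1]

for _ in range(5):
    stacked = tf.concat([image, label], axis=-1)     # (10, 10, 4)
    crop = tf.image.random_crop(stacked, (6, 6, 4))  # same window for all channels
    img_crop, lbl_crop = crop[:, :, :3], crop[:, :, 3:]
    # After a correct joint crop the label still equals image channel 0
    assert bool(tf.reduce_all(tf.equal(lbl_crop, img_crop[:, :, :1])))

print("image and label stay aligned")
```

Because every pixel value is unique, this check would fail if the label had been cropped with a different random offset than the image, which is exactly what cropping them separately would risk.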