tf.data处理数据全过程——代码详解

tf.data加载图片并进行数据增强

tf.data为tensorflow官方最为推荐的处理数据的模块。本文通过实例详细叙述使用tf.data处理数据的方法,其中包括读入数据、预处理数据、使用数据增强等,同时本文针对上述过程中涉及到的python 及 tensorflow 的api均进行了详细的说明,以求充分理解并自行完成满足实际需求的代码。本文使用tensorflow版本为v2.6且本文内容在v2.0以上版本均适用。

1. 数据集文件存储形式

本次示例使用花分类数据集,分为trianval两个文件夹,每个文件夹中均含有以5个类别命名的文件夹,文件中为相应的数据集图片。示例图如下

-- data   
   -- train
        -- daisy
        	-- *.jpg 
            ...
        -- dandelion 
        -- roses 
        -- sunflowers
        -- tulips
    -- val
        -- daisy
        -- dandelion 
        -- roses 
        -- sunflowers
        -- tulips
2. 获取数据集路径和类别名称
2. 1 获取类别名称
# class dict
data_class = [cla for cla in os.listdir(train_dir) if 	        		                               		 os.path.isdir(os.path.join(train_dir, cla))]
class_num = len(data_class)
class_dict = dict((value, index) for index, value in enumerate(data_class))

# reverse value and key of dict
inverse_dict = dict((val, key) for key, val in class_dict.items())

说明:

  1. train_dir为训练集文件夹的路径,其中包含以各类别命名的文件夹。所以通过os.listdir()获得train_dir中各个文件夹的名称,也就是各个类别的名称。即**data_class=[daisy, dendelion, roses, sunflowers, tulips]**。这里值得注意data_class列表获取的方法。
  2. class_dict获得类别名对应的序号,inverse_dict获得序号对应类别名。二者均为字典类型。
2.2 获得存放训练集和验证集的数据和标签的列表

后期获得训练集和验证集的迭代器,需要使用tf.data.Dataset.from_tensor_slices((all_images_path, all_imgs_labels))函数获得,其中需要图片和标签数据,图片数据以图片路径组成的列表表示,标签数据以与图片相应的数字列表组成。所以我们需要先构造出满足上述条件的两个列表!

# load train images list
train_image_list = glob.glob(train_dir+"/*/*.jpg")
random.shuffle(train_image_list)
train_num = len(train_image_list)
assert train_num > 0, "cannot find any .jpg file in {}".format(train_dir)
train_label_list = [class_dict[path.split(os.path.sep)[-2]] for path in train_image_list]

# load validation images list
val_image_list = glob.glob(validation_dir+"/*/*.jpg")
random.shuffle(val_image_list)
val_num = len(val_image_list)
assert val_num > 0, "cannot find any .jpg file in {}".format(validation_dir)
val_label_list = [class_dict[path.split(os.path.sep)[-2]] for path in val_image_list]

print("using {} images for training, {} images for validation.".format(train_num,
val_num))

说明:

  1. glob.glob(train_dir+"/*/*.jpg"):通过glob函数获得满足 文件名称条件 的文件根目录列表。具体api的使用方法见博客
  2. train_label_list:获得图片对应的label序号列表时可以利用图片路径。图片路径信息如下:‘D:\YU_Files\workspace\资料\up主网络实现资料\deep-learning-for-image-processing-master\data_set\flower_data\train\sunflowers\5067864967_19928ca94c_m.jpg’。其中sunflowers即为图片对应的类别信息。使用字符串的split方法即可获得。其中注意os.path.sep是操作系统用来分隔路径名组件的字符,这里相当于\
  3. 训练集和验证集的操作方式相同。
3. 加载训练集和验证集
def process_path(img_path, label):
    label = tf.one_hot(label, depth=class_num)
    image = tf.io.read_file(img_path)
    image = tf.io.decode_jpeg(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [im_height, im_width])
    return image, label

AUTOTUNE = tf.data.AUTOTUNE

# load train dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_image_list, train_label_list))
train_dataset = train_dataset.shuffle(buffer_size=train_num)\
.map(process_path, num_parallel_calls=AUTOTUNE)\
.repeat().batch(batch_size).prefetch(AUTOTUNE)

# load train dataset
val_dataset = tf.data.Dataset.from_tensor_slices((val_image_list, val_label_list))
val_dataset = val_dataset.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
.repeat().batch(batch_size)

说明:(以train_dataset为例子

  1. tf.data.Dataset.from_tensor_slices((imgs_path, labels)):创建数据集的常用方法,同时传入图片地址(或图片),和对应的标签,即可获得一个可迭代的数据集train_dataset

  2. train_dataset.map(map_func, num_parallel_calls=None, deterministic=None):对数据集进行变换(transformations)。该函数会对数据集中每一个元素执行map_func处理,并返回处理后的数据集。

    • num_parallel_calls:该参数表示要异步并行处理数据的数量,可以为tf.int64tf.Tensor类型数据。如果不指定,数据将按顺序处理。如果指定为**tf.data.AUTOTUNE**,则会根据可用的CPU动态设置并行处理数据的数量。
    • deterministic:布尔型。当num_parallel_calls确定时,deterministic将会控制数据转换过程中处理数据的顺序。具体参考。一般不指定(默认参数)。
  3. train_dataset.shuffle(buffer_size=train_num):shuffle()方法可以充分打乱数据,其中的buffer_size参数需要设置与数据集大小一致,以保证数据集被充分打乱。

  4. train_dataset.repeat(count=None):重复数据集,重复次数为count传入的参数。

    • count:可以为tf.int64或tf.Tensor类型数据,设置数据集重复的次数。若不设置(默认)或设为-1,则会无限期的重复数据集。
  5. train_dataset.batch(batch_size, drop_remainder=False, num_parallel_calls=None, deterministic=None):结合连续数量的数据为一个batch。**drop_remainder**设置为True时,当数据集的个数不能整除一个batch中数据的个数时,则会舍弃最后几个不够一个batch的数据。其余两个参数与map函数中相同。

  6. train_dataset.prefetch(buffer_size):通常在创建一个迭代数据集最后调用prefetch方法。可以使得处理当前数据的同时准备后续数据,尽快提供batch。当buffer_size设定为tf.data.AUTOTUNE时,将会自动调整缓冲区的大小(buffer size)。

  7. process path函数详解:

    • tf.one_hot(label, depth=class_num):将数据变为one-hot类型,depth为一个标量,决定one-hot的维度。在分类数据集上,则为类别的个数。

    • tf.io.read_file(img_path):返回包括文件所有内容的tensor, tensor的dtype为“string”。所以其一般为数据流的,得到的数据还需要进行解码。

    • tf.image/io.decode_jpeg(image, channel=0):示例中给的解码操作使用的是tf.image工具包,但tensorflow_v2.6版本所有解码操作已经都移入tf.io模块中。该函数作用为解码jpeg(jpg)图片为uint8的tensor。

      • channel参数可以指定为0, 1, 3。默认为0,表示安装原始jpg的通道数获取图片,如果设为1,则会解码为一个灰度图片。
      • 其他参数参考
    • tf.image.convert_image_dtype(image, tf.float32):该api实现对img数据类型进行转换,支持的数据类型有uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64, bfloat16当使用浮点数类型时,图片中数据的值将在 ( 0 , 1 ] (0,1] (01]之间;使用整型时,图片中数据的值将在 ( 0 , M a x ] (0, Max] (0Max]之间,Max为该整型的最大值。

    • image = tf.image.resize(image, [im_height, im_width],method=ResizeMethod.BILINEAR):resize图片到特定的尺寸,可以指定method参数确定resize使用的方法。各方法说明见官网

      当原始图片和resize后的图片长宽比 比例不一致时会出现扭曲,如果不想出现这个问题,可以使用**tf.image.resize_with_pad()**函数,其会保证resize后的图片部分长宽比与原始图像相同,其余部分通过padding 0来填充。

4. 数据增强

tensorflow支持的数据增强的方法主要有两种:1.使用**tf.image模块中提供的方法。2.使用keras中的预处理层**,比如tf.keras.layers.RandomFlip,tf.keras.layers.Resizing等。

由于上面我们已经使用过tf.image模块,也大致了解了其配合tf.data使用方法,同时**相对于使用keras预处理层来说,tf.image更具有灵活性和可控制性,且配合tf.data使用可以自定义数据增强方法。**所以我们先从tf.image实现数据增强进行说明。

4. 1 使用tf.image进行数据增强

为了方便观察数据增强API的效果,我们定义一个函数,用来对比原始图片和增强后的图片。

def visualize(original, augmented):
  fig = plt.figure()
  plt.subplot(1,2,1)
  plt.title('Original image')
  plt.imshow(original)

  plt.subplot(1,2,2)
  plt.title('Augmented image')
  plt.imshow(augmented)

tf.image提供了固定增强和随机增强等两大类增强方式,本文只列举每个类中的几种。具体api详见官网

Part1 固定参数增强

  1. 左右翻转(Flip an image)

    flipped = tf.image.flip_left_right(image)
    visualize(image, flipped)
    

在这里插入图片描述

  1. 改变图片饱和度(Saturate an image)

    saturated = tf.image.adjust_saturation(image, 3)
    visualize(image, saturated)
    

    tf.image.adjust_saturation函数需要传入图片和饱和度参数(3)。
    在这里插入图片描述

Part2 随机增强

注意:在tensorflow2.6版本中,tf.image含有两种随机增强api:tf.image.random*tf.image.stateless_random*官方更推荐使用后者。

  1. 随机亮度(Randomly change image brightness)

    for i in range(3):
      seed = (i, 0)  # tuple of size (2,)
      stateless_random_brightness = tf.image.stateless_random_brightness(
          image, max_delta=0.95, seed=seed)
      visualize(image, stateless_random_brightness)
    

    函数说明:tf.image.stateless_random_brightness需要传入三个参数

    • image:图像数组,
    • max_delta:亮度参数,指定后亮度变换范围为 [ − m a x _ d e l t a , m a x _ d e l t a ) [-max\_delta, max\_delta) [max_delta,max_delta)
    • seed:seed为随机种子,为两元素的元组类型。相同种子将会产生相同的变化规则。
      在这里插入图片描述
      在这里插入图片描述
      在这里插入图片描述
  2. 随机裁剪(Randomly crop an image)

    for i in range(3):
      seed = (i, 0)  # tuple of size (2,)
      stateless_random_crop = tf.image.stateless_random_crop(
          image, size=[210, 300, 3], seed=seed)
      visualize(image, stateless_random_crop)
    

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

函数说明:tf.image.stateless_random_crop中size为指定crop图像的大小。seed仍为size为(2,)的元组。

Part3 数据集应用增强

在上文加载训练集和验证集部分中我们已经介绍了dataset的map方法,可以发现通过map函数可以灵活的按照自己的要求处理数据,只需要定义一个预处理数据的函数。因此数据增强方法也可以定义在预处理函数中。预处理函数示例如下:

def resize_and_rescale(image, label):
  image = tf.cast(image, tf.float32)
  image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
  image = (image / 255.0)
  return image, label

先定义一个基本的数据预处理函数,包括图像数据的类型转化,resize和数据归一化(若使用tf.image.convert_image_dtype(image, tf.float32)则会自动将数据变为float32并归一化[0,1)之间)。

def augment(image_label, seed):
  image, label = image_label
  image, label = resize_and_rescale(image, label)
  image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
  # Make a new seed.
  new_seed = tf.random.experimental.stateless_split(seed, num=1)[0, :]
  # Random crop back to the original size.
  image = tf.image.stateless_random_crop(
      image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
  # Random brightness.
  image = tf.image.stateless_random_brightness(
      image, max_delta=0.5, seed=new_seed)
  image = tf.clip_by_value(image, 0, 1)
  return image, label

再定义一个数据增强函数,其中引用了resize_and_rescale函数对图像进行基本预处理。再增加亿点点细节,就可以随心所欲的增加数据增强方法了。

函数说明:

  1. tf.image.resize_with_crop_or_pad:对图片进行resize,该函数相较于tf.image.resize会根据原始图片比例适当增加padding或者crop后再进行resize。具体参考

  2. tf.random.experimental.stateless_split(seed, num):该函数需要传入一个RNG形式的seed(一个shape为2的tensor,类型为int32,或int64。比如seed=[1,2]),和返回seeds的个数num。该函数会返回一个shape为[num, 2]的新的seed。在augment函数代码中我们对该tf.random.experimental.stateless_split的返回值进行了切片操作,就是因为其返回值为[1, 2]的二维数组,而需要传入随机增强中的seed的shape应为(2, )的一维数组。

  3. **tf.image.stateless_random_croptf.image.stateless_random_brightness**都为随机增强函数,在上述示例中都有提到,这里不再叙述。

  4. tf.clip_by_value(image, 0, 1):将image数据限制在[0, 1]之间,调用该函数是为了防止数据增强过程中导致图像的数据不在[0, 1]范围之内。

Part4 随机种子的产生

在上文中我们提到了随机增强函数,可以了解到调用随机增强函数需要传入一个seed参数,传入相同的seed在其他参数相同的情况下会产生相同的增强效果。**所以我们需要创造一个seed迭代器,让其每次调用产生不同的seed。**官方提供了tf.random.Generator,我们只需要创建一个tf.random.Generator.from_seed实例,通过每次调用make_seeds方法即可得到随机的种子。具体实现见下文代码。

  1. 创建tf.random.Generator.from_seed实例,传入一个初始化种子值(这里为123),和产生随机数的方法,这里设为philox,其他产生随机数的方法,参见官网
# Create a generator.
rng = tf.random.Generator.from_seed(123, alg='philox')
  1. 调用make_seed方法,产生随机种子,传入前面定义的augment函数中。
# Create a wrapper function for updating seeds.
def f(x, y):
  seed = rng.make_seeds(2)[0]
  image, label = augment((x, y), seed)
  return image, label

说明:

  • make_seeds(count=1):该方法会返回一个shape为[2, count]的tensor,dtype为int64。

到此随机增强部分就全部明了了,接着使用map函数对数据进行f(x,y)就可以构造出一个经过增强后的数据集了。

train_datasets = 
    train_datasets
    .shuffle(1000)
    .map(f, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
5.数据集训练

训练部分使用tensorflow提供的高级api即可,model.fit可以直接接受上文中得到的train_datasets进行训练。训练部分代码详解将不在本文叙述,可以参考官网,或即将更新的下一篇博文。下面仅附上参考代码。

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
	loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
	metrics=["accuracy"])

callbacks = 		[tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex_{epoch}.h5',
	save_best_only=True,
	save_weights_only=True,
	monitor='val_loss')]

history = model.fit(x=train_dataset,
	steps_per_epoch=train_num // batch_size,
	epochs=epochs,
	validation_data=val_dataset,
	validation_steps=val_num // batch_size,
	callbacks=callbacks)
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值