365天深度学习训练营-第T10周:数据增强

 我的环境:

  • 语言环境:Python3.11.2
  • 编译器:PyCharm Community Edition 2022.3
  • 深度学习环境:TensorFlow2 

 一、设置数据

1.1 获取数据集

        先初步导入、设置数据

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers,models

data = 'F:/365-7-data'

img_height = 224
img_width = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data,
    validation_split=0.3,
    subset='training',
    seed=123,
    image_size=(img_height,img_width),
    batch_size=batch_size
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data,
    validation_split=0.3,
    subset='validation',
    seed=123,
    image_size=(img_height,img_width),
    batch_size=batch_size
)

        本次导入路径没有使用pathlib库而是直接导入。

        与直接使用文件路径相比,使用pathlib库操作文件路径有以下几个优点:

        1. 跨平台兼容性更好:`pathlib.Path` 可以自动适应不同操作系统的文件路径分隔符,无需手动处理。

        2. 更加安全:使用 `pathlib.Path` 可以避免一些常见的路径操作错误,例如路径拼接时忘记添加分隔符、路径中包含特殊字符等。

        3. 更加可读性高:使用 `pathlib.Path` 可以更加清晰地表达路径的含义,例如 `pathlib.Path('/home/user/data')` 比字符串 `'/home/user/data'` 更加易读。

        4. 更加灵活:`pathlib.Path` 提供了丰富的方法和属性,可以方便地进行路径操作,例如获取文件名、扩展名、父目录等。

1.2 获取测试集

        本次的数据中没有测试集,从验证集中抽取一部分作为测试集。

        使用tf.data.experimental.cardinality获取验证集数据的数量。它返回一个 `tf.data.experimental.Cardinality` 对象,该对象包含了数据集的元素数量信息。

        使用take方法获取数据,使用skip方法将测试集数据移出验证集。

val_batches = tf.data.experimental.cardinality(val_ds)
#取整
test_ds = val_ds.take(val_batches//5)
val_ds = val_ds.skip(val_batches//5)

print('%d'%tf.data.experimental.cardinality(val_ds))
print('%d'%tf.data.experimental.cardinality(test_ds))

1.3 继续配置 

AUTOTUNE = tf.data.AUTOTUNE
#归一化
def preprocessing_image(image,label):
    return (image/255.0,label)

train_ds = train_ds.map(preprocessing_image,num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(preprocessing_image,num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(preprocessing_image,num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

        这里的map方法是TensorFlow中的,而非python自带的。`tf.data.Dataset`对象的`map()`方法可以用于对数据集中的每个元素应用一个函数,返回一个新的数据集。

二、数据增强

        使用tf.keras.layers.experimental.preprocessing.RandomFlip与tf.keras.layers.experimental.preprocessing.RandomRotation进行数据增强。

        前者将图像在水平和垂直方向上随机反转,或者这是随机反转图像。

data_au = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal_and_vertical'),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
])

        前者表示水平与垂直随机反转,后者为按照0.2的弧度随机进行反转。

        `tf.keras.Sequential` 是 `tf.keras` 模块中的一个类,用于创建顺序模型。它是 `tf.keras.models.Sequential` 的子类。

for images,label in train_ds.take(1):
    for i in range(9):
        image = tf.expand_dims(images[i],0)
        aug = data_au(image)
        ax = plt.subplot(3,3,i+1)
        plt.imshow(aug[0])
        plt.axis('off')
plt.show()

        增强数据可以放在model中,这样在模型进行训练(fit)时,GPU便会帮助加速增强数据。

model = tf.keras.Sequential([
    data_au,
    layers.Conv2D(64,3,activation='relu'),
    layers.Dense(64)
])

        因为data_au是通过tf.keras.Sequential定义的,所有也相当于一个模型,这里直接将模型作为网络层添加到新模型中。

        也可以在数据集中使用map进行增强。

val_ds = val_ds.map(lambda x,y: (data_au(x,training=False),y,num_parallel_calls=AUTOTUNE)

         `num_parallel_calls` 是一个用于控制数据预处理并行度的参数。在数据预处理过程中,通常需要进行一些图像增强、数据标准化、数据裁剪等操作,这些操作可能会比较耗时。为了加快数据预处理的速度,可以使用 TensorFlow 的 `tf.data.Dataset.map()` 函数来对数据进行预处理,并通过 `num_parallel_calls` 参数来指定并行处理的线程数。

三、训练模型

model = tf.keras.Sequential([
    layers.Conv2D(16,3,padding='same',activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(16,3,padding='same',activation='relu'),
    layers.Flatten(),
    layers.Dense(32,activation='relu'),
    layers.Dense(len(class_name))
])

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

epochs = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

四、总结

        本次学习了数据集的增强方式,使用tf.keras.layers.experimental.preprocessing.RandomFlip与tf.keras.layers.experimental.preprocessing.RandomRotation进行数据增强。并在model或使用数据集的map方法来增强数据。有助于提高模型的准确率。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值