【学习记录】tensorflow图像数据增强

问题描述

使用 Keras 预处理层,对图片进行例如亮度,对比度,灰度,水平旋转,垂直旋转等的操作。降低神经网络模型对数据的过拟合。

数据集

采用tensorflow自带数据集beans

Beans 是使用智能手机相机在田间拍摄的豆类图像数据集。它由3个类别组成:2个疾病类别和健康类别。描述的疾病包括角叶斑病和豆锈病。数据由乌干达国家作物资源研究所 (NaCRRI) 的专家进行注释,并由 Makerere AI 研究实验室收集。

代码分析

1.在使用tfds.load数据集会出现无法获取google的token信息,从而无法下载数据集的情况

# 添加代码,取消认证
# 获得 Google 身份 
tfds.core.utils.gcs_utils._is_gcs_disabled = True

2.下载数据集并划分训练,验证,测试集

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras import layers

# # 查看数据集
# tfds.list_builders()

# 下载数据集合
(train_ds,val_ds,test_ds),metadata = tfds.load(
    'beans',
    split=['train[:80%]','train[80%:90%]','train[90%:]'],
    with_info=True, 
    as_supervised=True

)

# 数据集合的类别
num_classes = metadata.features['label'].num_classes
print(num_classes)
#可视化
get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))

3.1 使用keras预处理层-调整大小和重新缩放

'''调整大小和重新缩放'''
IMG_SIZE = 180
resize_and_rescale = tf.keras.Sequential([
    layers.Resizing(IMG_SIZE,IMG_SIZE),
    layers.Rescaling(1./255)
])

# 可视化
result = resize_and_rescale(image)
_ = plt.imshow(result)
# 验证像素是否在[0, 1]范围内:
print("Min and max pixel values:", result.numpy().min(), result.numpy().max())

3.1 使用keras预处理层-数据增强

'''数据增强'''
data_augmentation = tf.keras.Sequential([
  layers.RandomFlip("horizontal_and_vertical"),
  layers.RandomRotation(0.2),
])

# 可视化
# Add the image to a batch.
image = tf.cast(tf.expand_dims(image, 0), tf.float32)
plt.figure(figsize=(10, 10))
for i in range(3):
  augmented_image = data_augmentation(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0])
  plt.axis("off")

4.1 将预处理之后的图片传入模型--使预处理成为模型的一部分

'''使用keras预处理层---选项一:使预处理成为模型的一部分'''
model = tf.keras.Sequential([
    resize_and_rescale,
    data_augmentation,
    layers.Conv2D(16,3,padding="same",activation='relu'),
    layers.MaxPooling2D(),


    # 剩下的模型结构
    #在这种情况下,有两点需要注意:

    # 1. 数据增强将在设备上运行,与您的其余层同步,并受益于 GPU 加速。

    # 2. 当您使用 导出模型model.save时,预处理层将与模型的其余部分一起保存。
    # 如果您稍后部署此模型,它将自动标准化图像(根据您的层的配置)。
    # 这可以使您不必重新实现该逻辑服务器端的工作。
# ])

4.2 将预处理之后的图片传入模型--将预处理层应用于数据集

'''使用keras预处理层---选项二:将预处理层应用于您的数据集'''
aug_ds = train_ds.map(
  lambda x, y: (resize_and_rescale(x, training=True), y))
# 使用这种方法,您Dataset.map可以创建一个数据集,该数据集会产生成批的增强图像。在这种情况下:
  # 1.数据扩充将在 CPU 上异步发生,并且是非阻塞的。
  # 您可以使用数据预处理在 GPU 上对模型的训练进行重叠Dataset.prefetch

  # 2.在这种情况下,当您调用Model.save. 在保存模型或在服务器端重新实现它们之前
  #,您需要将它们附加到模型中。训练后,您可以在导出前附加预处理层。

5. 采用预处理应用于数据集的方法,在训练集上进行数据增强

'''将预处理层应用于数据集'''
batch_size = 32
# 数据接口
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
 # .map可以创建一个数据集,该数据集会产生成批的增强图像.
  ds = ds.map(lambda x, y: (resize_and_rescale(x), y), 
              num_parallel_calls=AUTOTUNE)

  if shuffle:
    ds = ds.shuffle(1000)

# 多个批次取数据集
  ds = ds.batch(batch_size)

# 只在训练集上使用数据增强
  if augment:
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), 
                num_parallel_calls=AUTOTUNE)

  # 在所有數據集上使用緩衝預取。
  return ds.prefetch(buffer_size=AUTOTUNE)

# 在训练集上进行数据增强和打乱顺序
train_ds = prepare(train_ds,shuffle=True,augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

6.构建模型,训练模型

'''构建模型'''
model = tf.keras.Sequential([
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),

  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),

  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),

  layers.Conv2D(8, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),

  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  
  layers.Dense(num_classes)
])
# model.summary()

# 编译模型
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# 训练模型
epochs=10
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
# 评估模型
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

参考链接

https://tensorflow.google.cn/tutorials/images/data_augmentation

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值