第十周 tensorflow中的数据增强

一、本周学习内容:

本周主要是学习数据增强函数和展示其效果

数据增强

方法一:嵌入model中

方法二:使用Dataset数据集进行增强

方法三:自定义数据增强函数

二、前言

我们的数据图片共有3400张,两个类别,类别即为文件夹名。
类别包括:[‘cat’, ‘dog’]

三、电脑环境

电脑系统:Windows 10
语言环境:Python 3.8.8
编译器:Pycharm 2021.1.3
深度学习环境:TensorFlow 2.8.0,keras 2.8.0
显卡及显存:RTX 3070 8G

四、前期准备

1、导入相关依赖项

import pathlib

import numpy as np
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import MaxPool2D

2、设置GPU(我下载的tensorflow-gpu 默认使用GPU)

只使用GPU

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

使用cpu和gpu
os.environ[“CUDA_VISIBLE_DEVICES”] = “-1”

3、加载数据集和展示

(1)、数据预处理

data_dir = 'dataset'
data_dir = pathlib.Path(data_dir)
print(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

batch_size = 32
img_height=224
img_width = 224


# 读取图片
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset='training',
    seed=469,
    image_size=(img_height,img_width),
    batch_size=batch_size,
    label_mode='categorical'
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset='validation',
    seed=469,
    image_size=(img_height,img_width),
    batch_size=batch_size,
    label_mode='categorical'
)

# 由于我们数据在存在训练集和验证集  我将验证集部分拆为测试集
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches//5)
val_ds = val_ds.skip(val_batches//5)
print("Number of validation batches:%d "%tf.data.experimental.cardinality(val_ds))
print("Number of validation batches:%d "%tf.data.experimental.cardinality(test_ds))

class_num = train_ds.class_names
print(class_num)


# 配置数据集
AUTOTUNE = tf.data.AUTOTUNE
# 创建数据归一化函数
def preprocess_image(image,label):
    return (image/255.0,label)
train_ds = train_ds.map(preprocess_image,num_parallel_calls = AUTOTUNE)
val_ds = val_ds.map(preprocess_image,num_parallel_calls = AUTOTUNE)
test_ds = test_ds.map(preprocess_image,num_parallel_calls = AUTOTUNE)
# 打乱数据  并载入内存
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

(2)、数据展示

图片展示

# 可视化数据
plt.figure(figsize=(15,10))  #创建画布 画布大小为w15h10
for images,labels in train_ds.take(1):
    for i in range(batch_size):
        ax = plt.subplot(5,8,i+1)
        plt.imshow(images[i])
        plt.title(class_num[np.argmax(labels[i])])
        plt.axis("off")
plt.show()  # pycharm加入这行才能显示

在这里插入图片描述

(3)、数据增强及其效果展示

image = tf.expand_dims(images[0],0)
data_augmeentation = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

在这里插入图片描述

五、在CNN网络加入数据增强

两种选其中一种即可

方法一:嵌入到model中

好处是 你使用GPU训练时 这部分也能得到GPU加速

model = tf.keras.Sequential([
	data_augmentation,
	layers.Conv2D(16,3,padding='same',activation='relu'),
	layers.MaxPooling2D(),
])

注意:模型训练时才会进行增强,模型评估和预测时不会增强

方法二:在Dataset数据集中进行数据增强

batch_size=32
AUTOTUNE=tf.data.AUTOTUNE
def prepare(ds):
ds = ds.map(lambd x,y:(data_augmentation(x,training=True),y),num_parallel_cells=AUTOTUNE)
train_ds = prepare(train_ds)

自定义数据增强函数

def aug_img(image):
    seed = (random.randint(0,9),0)
    # 随机改变图片对比度
    stateless_random_brightness= tf.image.stateless_random_contrast(image,lower=0.1,upper=1.0,seed=seed)
    return stateless_random_brightness
image = tf.expand_dims(images[1]*255,0)
print("Min and max pixel values:",image.numpy().min(),image.numpy().max())
plt.figure(figsize=(8,8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3,3,i+1)
    plt.imshow(augmented_image[0].numpy().astype('uint8'))
    plt.axis('off')
plt.show()

在这里插入图片描述

以上就是我本周的学习内容
在这里插入图片描述

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

降花绘

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值