# 猫狗数据集下载：DATA_CAT | Kaggle
# 数据集共有 25000 张图片
# 导入相关的库
import warnings
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Suppress every warning category for the remainder of the process so the
# notebook output stays readable. Note this also hides library deprecation
# notices.
warnings.simplefilter("ignore")
# 导入数据（训练集与验证集）
# Dataset root as mounted in the Kaggle notebook; it contains "train" and
# "val" splits. image_dataset_from_directory (default labels='inferred')
# derives each image's class label from its sub-directory name.
PATH = '/kaggle/input/data-cat'
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'val')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)  # (height, width) each image is resized to on load

# Training batches are shuffled so every epoch sees the images in a new order.
train_ds = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                       shuffle=True,
                                                       batch_size=BATCH_SIZE,
                                                       image_size=IMG_SIZE)
# The validation split is deliberately NOT shuffled. Aggregate metrics do not
# depend on order. With shuffle=True, however, the order changes on every
# iteration, so predictions from model.predict(val_ds) could not be matched
# to labels gathered in a separate pass over val_ds.
val_ds = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                     shuffle=False,
                                                     batch_size=BATCH_SIZE,
                                                     image_size=IMG_SIZE)
# 数据增强
# Random augmentation pipeline: a horizontal mirror followed by a small random
# rotation (factor 0.1, a fraction of a full turn). These Keras preprocessing
# layers only transform inputs in training mode.
data_augmentation = keras.Sequential()
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(0.1))
# 建立模型
inputs = keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = layer