使用 TensorFlow 和 Keras 构建卷积神经网络进行熊猫图片分类


前言

在本篇博客中,我将介绍如何使用 TensorFlow 和 Keras 构建一个卷积神经网络(CNN)模型,用于识别熊猫图片和非熊猫图片。我们将通过图像数据增强技术来扩展训练数据,并使用早停和模型检查点回调来优化模型的训练过程。

数据集下载链接:https://download.csdn.net/download/weixin_48839391/89753951


1.项目背景

在这次项目中,我们的目标是通过卷积神经网络(CNN)来分类图片,判断图片中的动物是否为熊猫。我们使用了 Keras 的 ImageDataGenerator 来进行数据增强,确保模型能够从有限的图片集中学到更具代表性的特征。最终,训练好的模型会保存为 best_model.h5。

2.数据准备

我们的数据集分为两部分:

训练集:用于训练模型,图片经过数据增强后会传入模型。
验证集:用于在每个训练周期结束时评估模型的性能,确保模型没有过拟合。
我们使用了如下路径存储数据:

train_dir = ‘data/train’
validation_dir = ‘data/validation’
通过 os.walk() 函数,我们可以统计训练和验证集中的图片数量。

3.数据准备

为了让模型更具泛化能力,避免过拟合,我们使用了数据增强技术。在训练过程中,图片会经过一系列随机变换,例如旋转、平移、缩放等操作:

train_datagen = ImageDataGenerator(
    rescale=1.0 / 255.0,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

4.构建模型

我们选择了一个标准的卷积神经网络架构,包含三个卷积层和池化层,最终通过全连接层输出一个概率值,判断图片属于哪一类(熊猫或非熊猫)。模型的结构如下:

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])

5.测试

import cv2
import numpy as np
import tensorflow as tf


def predict_image(image_path, model):
    img = cv2.imread(image_path)
    img = cv2.resize(img, (150, 150))
    img = np.expand_dims(img, axis=0) / 255.0

    prediction = model.predict(img)
    return "Panda" if prediction[0][0] >= 0.5 else "Not Panda"


if __name__ == '__main__':
    # 加载模型
    model = tf.keras.models.load_model('best_model.h5')

    # 测试图片
    result = predict_image('data/images/panda/Image_13.jpg', model)
    print(result)

    result = predict_image('data/images/panda/Image_8.jpg', model)
    print(result)

    result = predict_image('data/images/not panda/Image_6.jpg', model)
    print(result)

    result = predict_image('data/images/not panda/Image_34.jpg', model)
    print(result)

结果:
在这里插入图片描述
可以知道模型训练结果良好,基本符合预期

6.总结

通过这次实践,我们学习了如何使用卷积神经网络进行图像分类,并通过数据增强技术增强了模型的泛化能力。我们还使用了早停和模型检查点功能,确保训练过程能够得到最优的模型。该模型可以用于解决其他类似的图像分类任务,只需根据实际需求调整数据集即可。

附录(全部代码)

tensorflow 采用2.11.0版本

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import os

# 数据路径
train_dir = 'data/train'
validation_dir = 'data/validation'

# 统计图片数量
num_train_images = sum([len(files) for r, d, files in os.walk(train_dir)])
num_validation_images = sum([len(files) for r, d, files in os.walk(validation_dir)])

# 图像数据增强和预处理
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255.0,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)

# 生成训练和验证数据
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary'
)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary'
)

# 构建模型
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 回调函数
checkpoint = ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss', mode='min')
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# 训练模型
history = model.fit(
    train_generator,
    steps_per_epoch=num_train_images // 20,  # 调整 steps_per_epoch
    epochs=10,
    validation_data=validation_generator,
    validation_steps=num_validation_images // 20,  # 调整 validation_steps
    callbacks=[checkpoint, early_stopping],
    use_multiprocessing=False,  # 禁用多线程
    workers=1  # 限制工作线程数为1
)

# 最好的模型会自动保存为 best_model.h5
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值