基于Python深度学习的垃圾图像分类识别系统
文章目录
基于Python深度学习的垃圾图像分类识别系统
设计一个基于Python的深度学习垃圾图像分类识别系统。这个系统将包括以下主要部分:
- 数据预处理
- 模型构建与训练
- 图形用户界面 (GUI)
- 实时检测功能
1. 系统需求和库安装
首先,我们需要安装必要的库(代码统一通过 `tensorflow.keras` 导入 Keras,即使用 TensorFlow 2.x 自带的 Keras):
pip install tensorflow keras numpy matplotlib PyQt5 opencv-python scikit-learn
2. 数据集准备
我们将使用Kaggle的垃圾分类数据集:https://www.kaggle.com/asdasdasasdas/garbage-classification
数据集包含6个类别。解压后,数据集根目录(即代码中的 `data_dir`)下应为每个类别各有一个同名子目录,`flow_from_directory` 会根据子目录名推断类别:
- 玻璃 (glass)
- 纸张 (paper)
- 硬纸板 (cardboard)
- 塑料 (plastic)
- 金属 (metal)
- 其他垃圾 (trash)
# data_preprocessing.py
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def load_data(data_dir, img_size=(224, 224), batch_size=32):
    """Build training and validation image iterators from one directory tree.

    ``data_dir`` is read with ``flow_from_directory``, so every class is
    expected to be a sub-directory of it. A single generator with
    ``validation_split=0.2`` is shared by both subsets, and pixel values are
    rescaled by 1/255.

    Args:
        data_dir: Dataset root directory.
        img_size: (height, width) that every image is resized to.
        batch_size: Number of images per batch.

    Returns:
        (train_generator, val_generator), both yielding one-hot
        ('categorical') labels.
    """
    # NOTE(review): MobileNetV2's ImageNet weights expect inputs scaled to
    # [-1, 1] (mobilenet_v2.preprocess_input); [0, 1] scaling here works but
    # may cost accuracy. Any change must be mirrored in inference code.
    shared_gen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)
    flow_kwargs = {
        'target_size': img_size,
        'batch_size': batch_size,
        'class_mode': 'categorical',
    }
    # Training subset is created first, then validation (same order as before).
    train_generator, val_generator = (
        shared_gen.flow_from_directory(data_dir, subset=subset, **flow_kwargs)
        for subset in ('training', 'validation')
    )
    return train_generator, val_generator
3. 模型构建 (使用迁移学习 - MobileNetV2)
# model_builder.py
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
def build_model(input_shape=(224, 224, 3), num_classes=6):
    """Create and compile a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone (without its top) is frozen,
    and a new head is stacked on it: global average pooling,
    Dense(1024, relu), Dense(512, relu), Dense(num_classes, softmax).

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of output classes.

    Returns:
        A compiled keras ``Model`` (Adam lr=0.001, categorical cross-entropy,
        accuracy metric).
    """
    backbone = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
    # Only the new head is trained; the pretrained feature extractor stays fixed.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False

    head_layers = [
        GlobalAveragePooling2D(),
        Dense(1024, activation='relu'),
        Dense(512, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ]
    features = backbone.output
    for head_layer in head_layers:
        features = head_layer(features)

    model = Model(inputs=backbone.input, outputs=features)
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
4. 训练模型
# train.py
from data_preprocessing import load_data
from model_builder import build_model
def train_model(data_dir, epochs=10, batch_size=32, img_size=(224, 224),
                model_path='garbage_classification_model.h5'):
    """Train the garbage classifier on the images under ``data_dir`` and save it.

    Args:
        data_dir: Dataset root with one sub-directory per class.
        epochs: Number of training epochs.
        batch_size: Images per batch.
        img_size: (height, width) images are resized to; also determines the
            model's input shape.
        model_path: File the trained model is saved to. The GUI loads
            'garbage_classification_model.h5', which is the default here.

    Returns:
        (model, history): the trained Keras model and the History object
        returned by ``fit``.
    """
    train_generator, val_generator = load_data(
        data_dir, img_size=img_size, batch_size=batch_size)
    # Size the model from the data instead of assuming 6 classes / 224x224:
    # a dataset with a different number of class folders would otherwise fail
    # with a shape mismatch between labels and the softmax layer.
    model = build_model(input_shape=(*img_size, 3),
                        num_classes=train_generator.num_classes)
    model.summary()
    history = model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        validation_data=val_generator,
        validation_steps=len(val_generator),
        epochs=epochs
    )
    # Save the trained model
    model.save(model_path)
    return model, history
if __name__ == '__main__':
    import sys

    # Dataset root may be passed on the command line: python train.py <data_dir>
    # Without an argument, the placeholder below must be edited to the real path.
    data_dir = sys.argv[1] if len(sys.argv) > 1 else 'path_to_your_dataset'
    train_model(data_dir)
5. 创建图形用户界面 (PyQt5)
# gui_interface.py
import sys
import cv2
import numpy as np
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QVBoxLayout, QWidget, QFileDialog
from PyQt5.QtGui import QPixmap, QImage, QFont
from PyQt5.QtCore import Qt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
class GarbageClassificationApp(QMainWindow):
    """Main window that classifies a user-selected image with the trained model.

    The window shows a title, the selected image, a result label and a
    button that opens a file dialog and runs the prediction.
    """

    def __init__(self, model_path='garbage_classification_model.h5'):
        """Create the window, load the Keras model from ``model_path`` and build the UI."""
        super().__init__()
        self.setWindowTitle("垃圾图像分类识别系统")
        self.setGeometry(100, 100, 800, 600)
        # Load the trained model (train.py saves it under the default name)
        self.model = load_model(model_path)
        # Index i of the softmax output maps to class_names[i]. This must match
        # the class indices used in training: flow_from_directory assigns them
        # to the class sub-directories in alphanumeric order, as listed here.
        self.class_names = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
        # Build the UI
        self.init_ui()

    def init_ui(self):
        """Create the title, image area, result label and the select button."""
        # Central widget and layout
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout()
        # Title label
        title_label = QLabel("垃圾图像分类识别系统")
        title_label.setAlignment(Qt.AlignCenter)
        title_label.setFont(QFont("Arial", 18, QFont.Bold))
        layout.addWidget(title_label)
        # Image display label
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("border: 1px solid black")
        layout.addWidget(self.image_label)
        # Result label
        self.result_label = QLabel("请选择一张图片进行预测")
        self.result_label.setAlignment(Qt.AlignCenter)
        self.result_label.setFont(QFont("Arial", 14))
        layout.addWidget(self.result_label)
        # Button
        self.btn_select = QPushButton("选择图片")
        self.btn_select.clicked.connect(self.select_image)
        layout.addWidget(self.btn_select)
        # Install the layout on the central widget
        central_widget.setLayout(layout)

    def select_image(self):
        """Let the user choose an image file, display it and show the prediction."""
        filename, _ = QFileDialog.getOpenFileName(
            self, "选择图片", "", "Image Files (*.png *.jpg *.jpeg)")
        if not filename:
            return  # dialog was cancelled
        # Show the original picture, scaled to fit the label
        pixmap = QPixmap(filename).scaled(
            self.image_label.width(), self.image_label.height(),
            Qt.KeepAspectRatio, Qt.SmoothTransformation)
        self.image_label.setPixmap(pixmap)

        image = self._read_rgb_image(filename)
        if image is None:
            # Report instead of letting cv2.resize(None) raise inside the Qt slot
            self.result_label.setText("无法读取该图片")
            return
        label, prob = self._classify(image)
        self.result_label.setText(f"预测结果: {label} ({prob:.2f}%)")

    @staticmethod
    def _read_rgb_image(filename):
        """Decode an image file into an RGB uint8 array, or return None on failure."""
        # np.fromfile + imdecode instead of cv2.imread: unlike imread, this
        # also works for non-ASCII (e.g. Chinese) paths on Windows.
        try:
            data = np.fromfile(filename, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        bgr = cv2.imdecode(data, cv2.IMREAD_COLOR)
        if bgr is None:
            return None
        # OpenCV decodes to BGR; the training pipeline fed the model RGB images.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def _classify(self, rgb_image):
        """Return (class_name, confidence_percent) for an RGB uint8 image."""
        # Mirror the training preprocessing: 224x224 input, pixels scaled by 1/255.
        resized = cv2.resize(rgb_image, (224, 224)).astype("float32") / 255.0
        batch = np.expand_dims(resized, axis=0)
        preds = self.model.predict(batch, verbose=0)[0]  # verbose=0: no progress bar
        i = int(np.argmax(preds))
        return self.class_names[i], float(preds[i]) * 100
if __name__ == '__main__':
    # Start the Qt event loop with the image-classification window and
    # propagate its exit code to the shell.
    qt_app = QApplication(sys.argv)
    main_window = GarbageClassificationApp()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
6. 实时摄像头检测功能
# real_time_detection.py
import sys

import cv2
import numpy as np
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtWidgets import QApplication

from gui_interface import GarbageClassificationApp


class RealTimeDetectionApp(GarbageClassificationApp):
    """Variant of the classifier window that classifies live webcam frames.

    The inherited button becomes a camera on/off toggle; while the camera is
    on, a QTimer grabs and classifies one frame every 20 ms.
    """

    def __init__(self):
        # The base constructor calls self.init_ui(), i.e. the override below.
        super().__init__()
        self.cap = None  # cv2.VideoCapture while the camera is on, else None
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_frame)

    def init_ui(self):
        """Build the base UI, then repurpose its button to toggle the camera."""
        super().init_ui()
        self.btn_select.setText("打开摄像头")
        self.btn_select.clicked.disconnect()  # drop the inherited select_image handler
        self.btn_select.clicked.connect(self.toggle_camera)

    def toggle_camera(self):
        """Open camera 0 and start the frame timer, or stop and release it."""
        if self.cap is None or not self.cap.isOpened():
            self.cap = cv2.VideoCapture(0)
            if not self.cap.isOpened():
                # Don't start polling a device that failed to open.
                self.cap.release()
                self.cap = None
                self.result_label.setText("无法打开摄像头")
                return
            self.timer.start(20)  # grab a frame every 20 ms
            self.btn_select.setText("关闭摄像头")
        else:
            self._stop_camera()
            self.btn_select.setText("打开摄像头")
            self.image_label.clear()
            self.result_label.setText("请选择一张图片进行预测")

    def _stop_camera(self):
        """Stop the frame timer and release the camera if it is open."""
        self.timer.stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None

    def closeEvent(self, event):
        """Release the camera when the window is closed."""
        self._stop_camera()
        super().closeEvent(event)

    def update_frame(self):
        """Grab one frame, display it and show the model's prediction for it."""
        if self.cap is None:
            return
        ret, frame = self.cap.read()
        if not ret:
            return
        # OpenCV delivers BGR; convert once and use RGB for display and model,
        # matching the RGB images the model was trained on.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h, w, ch = rgb.shape
        qt_image = QImage(rgb.data, w, h, ch * w, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qt_image).scaled(
            self.image_label.width(), self.image_label.height(),
            Qt.KeepAspectRatio, Qt.SmoothTransformation)
        self.image_label.setPixmap(pixmap)

        # Same preprocessing as training: 224x224, pixels scaled by 1/255.
        batch = cv2.resize(rgb, (224, 224)).astype("float32") / 255.0
        batch = np.expand_dims(batch, axis=0)
        # Call the model directly: predict() has per-call overhead (and prints
        # a progress bar), which matters when this runs every 20 ms.
        preds = np.asarray(self.model(batch, training=False))[0]
        i = int(np.argmax(preds))
        label = self.class_names[i]
        prob = float(preds[i]) * 100
        self.result_label.setText(f"预测结果: {label} ({prob:.2f}%)")


if __name__ == '__main__':
    app = QApplication(sys.argv)
    window = RealTimeDetectionApp()
    window.show()
    sys.exit(app.exec_())
7. 系统使用说明
- 训练模型:
  - 下载并解压垃圾图像数据集
  - 将 `train.py` 中的 `data_dir` 设置为数据集所在目录,然后运行 `train.py` 开始训练模型
  - 训练完成后会生成 `garbage_classification_model.h5` 文件
- 图像分类识别:
  - 运行 `gui_interface.py` 启动应用程序(`garbage_classification_model.h5` 需位于当前工作目录)
  - 点击"选择图片"按钮,选择需要识别的垃圾图片
  - 程序会显示预测结果及其置信度
- 实时检测:
  - 运行 `real_time_detection.py` 启动实时检测程序
  - 点击"打开摄像头"按钮开始实时检测
  - 程序会持续显示摄像头画面,并对当前整帧画面进行分类,实时显示预测的垃圾类别
8. 模型优化建议
- 数据增强:可以增加更多的数据增强技术来提高模型泛化能力
- 模型调整:尝试不同的预训练模型(如ResNet、Inception等)进行迁移学习
- 超参数调优:调整学习率、批次大小等参数以获得更好的性能
- 多尺度检测:对于不同尺寸的垃圾物体,可以实现多尺度检测
- 部署优化:使用TensorRT或OpenVINO等工具对模型进行优化以便部署
这是一个完整的垃圾图像分类识别系统实现方案,包含了从数据预处理、模型构建、训练到图形界面开发的完整流程。
以上为示例代码,仅供参考
关键模型:VGG19 + DenseNet121 + ResNeXt101(注:上文示例代码使用的是 MobileNetV2)
包含内容:数据集 + PPT + 文档 + 代码