快速复现——基于 tensorflow 的花卉识别
参考链接:基于TensorFlow的花卉识别
前置知识了解:卷积神经网络(CNN)详细介绍及其原理详解
本文可以帮助你在 Ubuntu 20.04 环境下,快速部署一个较为简单的深度学习模型,你也可以扩展该模型,识别更多种类的花卉。本文提供全部源码和关键的代码分析。
一、环境准备
1.1 Anaconda 下载
-
本地环境:ubuntu 20.04, 无GPU,无CUDA
-
下载:通过 wget 命令下载或者去官网下载
wget -P /tmp https://repo.anaconda.com/archive/Anaconda3-2024.02-1-Linux-x86_64.sh
-
安装:打开 Anaconda 安装包所在路径的终端,运行脚本
bash Anaconda3-2024.02-1-Linux-x86_64.sh
运行后看到输出一些安装许可,默认位置,最后是 conda init ,输入 yes 后安装程序将 conda 的路径添加到系统的 PATH(查看 ./bashrc 文件 )
1.2 安装PyCharm
下载安装包,直接安装,创建桌面快捷方式即可,新建项目时选择对应的 python 解释器即可
1.3 安装 tensorflow-CPU
如果支持 GPU,建议安装 tensorflow-GPU 版本,训练模型更快,由于我的是虚拟机,故使用 tensorflow-CPU
-
创建 conda 虚拟环境
conda create -n tensorflow_py_3.8 python=3.8.18
-
激活环境
conda activate tensorflow_py_3.8
-
安装 tensorflow-cpu 、Pyqt5、matplotlib等依赖包
conda install tensorflow==2.12.0 sudo apt-get install mesa-utils pip install PyQt5 pip install scikit-learn pip install matplotlib
-
测试安装
python #启动 python
import tensorflow as tf a = tf.constant(10) b = tf.constant(12) result = a + b print(result.numpy()) #输出22即为安装成功
有警告是正常的,不会运行不起来,是因为 CPU 支持 SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA 这些指令集的加速,需要用 bazel 重新编译 tensorflow 源代码并安装,但 CPU 在这快也快不过 GPU1,且本文只提供一些较为简单的项目部署和代码分析,故忽略这些信息即可。
二、数据集准备
本次数据集采用 tensorflow 的官方数据集 flower_photo ,共有五种花卉,分别是雏菊、蒲公英、玫瑰、向日葵和郁金香)的图片,并有对应类别的标识(daisy、dandelion、roses、sunflowers、tulips)
对应样本数(对数据集进行分类,90%训练集,10%验证集):
二、模型设计思路
使用 TensorFlow 环境下的卷积神经网络 CNN 技术,通过 CNN 对数据集进行对应的训练,建立相关模型,再使用模型对相对应花卉进行识别。其中神经网络的建立使用 TensorFlow 2.x 的 Keras 的 Api 进行搭建,绘制损失函数和准确率曲线对模型训练效果进行评价,训练完成保存为 mode.h5 文件储存,在预测数据时读取 model 文件加载 model ,再使用 model 进行预测数据,并使用 Pyqt5 工具设计一个简洁的 GUI 界面进行人机交互,能够自定义预测数据集图片。
三、模型设计
第三步模型设计有项目全部代码,也可通过终端 git 克隆到本地,不过源码有点问题,我做出了一些改进,比如增加线程池数目,优化图片识别逻辑,可以自己建立项目结构,然后直接按照我给的代码复制粘贴即可
克隆源码命令如下(自选):
git clone https://gitee.com/steven_L1047/tensor-flow.git
直接构建项目结构如图:构建红字标记的文件,model.h5(这个不用构建,训练时生成) ,其他是多余文件
项目使用的 python 解释器是之前用 conda 创建好的 tensorflow 虚拟环境下的 python3.8
3.1 train.py 训练模型的代码分析
导入一些用到的 Python 库和模块
# matplotlib.pyplot:用于绘制图表和可视化数据,更好评估模型性能
import matplotlib.pyplot as plt
# tensorflow:用于构建深度学习模型,tensorflow 是一个深度学习框架
import tensorflow as tf
# tensorflow.keras.layers 、tensorflow.keras.models:用于构建神经网络模型的层和模型。
from tensorflow.keras import layers, models
# numpy:用于进行数值计算和数组操作,模型训练涉及到大量的矩阵运算
import numpy as np
from keras.utils.np_utils import to_categorical
from sklearn.metrics import confusion_matrix
定义用于训练模型的函数 train()
def train():
# 分批处理,制定每次训练数据的尺寸
batch_size = 32
img_height = 180
img_width = 180
# 调整TensorFlow的线程池,设置线程数为 4
config = tf.compat.v1.ConfigProto(inter_op_parallelism_threads=4)
session = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(session)
# 作用:使用TensorFlow的Keras API中的image_dataset_from_directory函数从指定目录加载图像数据集,并将其划分为训练集和验证集。
# 返回值:该函数返回一个 tf.data.Dataset 对象,可以直接用于模型训练。
# paramter1: --path--存储图像数据集的目录路径
# paramter2: --validation_split--将数据集划分为训练集和验证集时,验证集所占的比例
# paramter3:--subset--加载的数据集类型,training/validation
# paramter4:--seed--随机种子,用于数据集划分的随机性,同一个数据集划分的话用同一个种子,保证加载的训练集和验证集的划分一致
# paramter5:--image_size--加载的图像尺寸,这里是指定图像的高度和宽度
# paramter6:--batch_size--批量大小,即每次训练时从数据集中取出的样本数量
# 训练集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
'./flower_photos/',
validation_split=0.1,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
# 验证集
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
'./flower_photos/',
validation_split=0.1,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
# 从数据集中获取分类名
class_names = train_ds.class_names
print(class_names)
# 表示数据集中的类别数量为5类,可调整
num_classes = 5
# 定义常量AUTOTUNE,用于自动调整参数。
AUTOTUNE = tf.data.experimental.AUTOTUNE
# 处理数据集,并加载至运行内存,使用了 cache()缓存数据集到内存中,加速数据读取,prefetch() 预取操作提前加载数据以减少训练时的等待时间,shuffle() 随机重排增加数据的随机性,参数 1000 表示每次从数据集中随机抽取1000个样本进行训练。
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# 神经网络构建
# models 是 Keras 中的模型模块,Sequential 是一种模型类型,表示按顺序堆叠各种神经网络层。通过创建一个 Sequential 模型,可以方便地按顺序添加各种层,并构建神经网络模型。
model = models.Sequential([
# 输入层:将数据归一化,并设置input_shape输入(图片高,宽,RGB 通道)
layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),
# 卷积层1:卷积核数目为16,卷积核为3*3,激活函数为relu,并设置input_shape为(180,180,3),使用卷积的目的是从输入图片中提取特征
# 池化层1:采用最大池化操作,使用2*2采样,池化层的目的是降低了每个特征映射的维度,但是保留最重要的信息
layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),
layers.MaxPooling2D((2, 2)),
# 再添加 3 个卷积层和池化层,每个卷积层的卷积核数目逐渐增加,用于提取更加复杂的特征。
layers.Conv2D(32, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
# Flatten层:连接卷积层与全连接层,把多维的输入一维化
layers.Flatten(),
# 全连接层:神经元数目 units 设置为128,即输出维度为128,,激活函数为relu,全连接层对上一层的神经元进行全部连接,实现特征的非线性组合,进行特征进一步提取
layers.Dense(128, activation='relu'),
# 输出层,输出层包含num_classes个神经元,用于输出模型的预测结果。
layers.Dense(num_classes)
])
model.summary() # 打印网络结构
# 模型编译:优化器 optimizer 选adam (一种常用的优化算法,具有较快的收敛速度和较好的性能表现),损失函数 loss 选 SparseCategoricalCrossentropy(一种常用的多分类损失函数,适用于标签为整数形式的分类问题。参数from_logits=True表示模型的输出是未经过softmax激活的原始logits值,损失函数会在内部进行softmax计算),指标 metrics 选择准确率 accuracy(指定了评估指标为准确率(accuracy)。在训练过程中,模型会根据准确率来评估模型的性能,以便监控模型的训练过程)
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 训练模型:指定训练集 train_ds ,验证集 validation_data 为 val_ds ,迭代10次
# 返回值:model.fit()方法会返回一个history对象,其中包含了训练过程中的损失值和评估指标值的历史记录,可以用于后续的可视化和分析
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=10
)
# 训练结束,保存model为model.h5,HDF5格式(.h5文件):保存模型的格式,可以保存模型的架构、权重和优化器状态等信息,方便后续重用模型、分享模型、部署模型
model.save('model.h5')
# 使用evaluate评价模型,并打印准确率,verbose=2:指定了输出详细程度为2,即在评估过程中会输出每个batch的评估结果
test_loss, test_acc = model.evaluate(val_ds, verbose=2)
print("验证准确率为:", test_acc)
# 数据可视化演示,展示用于训练的图片
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
for i in range(20):
ax = plt.subplot(5, 10, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[labels[i]])
plt.axis("off")
plt.show()
#从 history 中获取一些模型的参数分析
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
#模型训练历史可视化,获取准确率和损失值并绘制函数
plt.figure()
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.ylim([0, 1])
plt.title('Training and Validation Accuracy')
plt.show()
plt.plot(loss, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss')
plt.show()
#绘制混淆矩阵
val_label = []
count = 0
for images, labels in val_ds:
val_label.extend(labels)
count = count+1
one_hot_train_label3 = to_categorical(val_label)
plot_confuse(model, val_ds, one_hot_train_label3)
return True
下面的代码是绘制混淆矩阵的两个函数,混淆矩阵是评估分类模型性能的一种工具,也是一种直观反映模型的识别成功率的图表。它展示了模型在不同类别上的预测结果与实际标签之间的对应关系。混淆矩阵的行代表实际类别,列代表预测类别,每个单元格中的值表示实际属于该行类别但被预测为该列类别的样本数量。
从上面展示的混淆矩阵图图表可以看出,一个合格的模型在矩阵的对角线上的值是比较高的,这说明该模型能够较为准确的识别出所给的图片种类,当然,通过图片还可以看出 tulips 被识别为 roses 的情况还是偏多,所以还要对模型进行调优和改进。
labels = ['dasiy', 'dandelion', 'roses', 'sunflowers', 'tulips']
def plot_confusion_matrix(cm, target_names, title='Confusion matrix', cmap=plt.cm.Greens, normalize=True):
accuracy = np.trace(cm) / float(np.sum(cm))
misclass = 1 - accuracy
proportion = []
length = len(cm)
for i in cm:
for j in i:
temp = j / (np.sum(i))
proportion.append(temp)
pshow = []
for i in proportion:
pt = "%.2f%%" % (i * 100)
pshow.append(pt)
proportion = np.array(proportion).reshape(length, length) # reshape(列的长度,行的长度)
pshow = np.array(pshow).reshape(length, length)
if cmap is None:
cmap = plt.get_cmap('Blues')
plt.figure(figsize=(15, 12))
plt.imshow(proportion, interpolation='nearest', cmap=cmap) # 按照像素显示出矩阵
plt.title(title)
plt.colorbar()
if target_names is not None:
tick_marks = np.arange(len(target_names))
plt.xticks(tick_marks, target_names, rotation=45)
plt.yticks(tick_marks, target_names)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
thresh = cm.max() / 1.5 if normalize else cm.max() / 2
iters = np.reshape([[[i, j] for j in range(length)] for i in range(length)], (cm.size, 2))
for i, j in iters:
if (i == j):
plt.text(j, i, pshow[i, j], horizontalalignment="center", fontsize=13, color="white" if cm[i, j] > thresh else "black")
else:
plt.text(j, i, pshow[i, j], horizontalalignment="center", fontsize=13,color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label accuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
plt.tight_layout()
plt.show()
# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
predictions = model.predict(x_val, batch_size=32)
predicted_classes = np.argmax(predictions, axis=1)
truelabel = np.argmax(y_val, axis=1) # 将one-hot转化为label
conf_mat = confusion_matrix(y_true=truelabel, y_pred=predicted_classes)
plt.figure()
plot_confusion_matrix(conf_mat, target_names=labels, title='Confusion Matrix')
3.2 validate.py 调用模型识别图片的代码分析
这个 .py 文件会加载用 train.py 训练好的模型(model.h5 文件),然后将前端得到的要识别图片的路径传入模型进行识别,并返回识别结果。
import os
import tensorflow as tf
from tensorflow import keras
import numpy as np
flower_dict = {0: 'dasiy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
img_height = 180
img_width = 180
# 定义加载模型的函数
def load_model_if_exists(model_path):
if os.path.exists(model_path):
model = keras.models.load_model(model_path, compile=True)
return model
else:
return None
# 加载模型
model = load_model_if_exists('./model.h5')
@tf.function
def predict_function(data):
return model(data)
# 根据方法参数path加载图片数据转为array类型,由于维度问题需要扩展1维,使用numpy的expand_dims方法将数据由3维扩展为4维,然后使用model.predict方法将图片数据作为参数调用,返回result结果(index)对应flower_dict相对应index,返回对应index的种类名作为结果
def validate(path):
if model is None:
return "Model file not found. Please make sure the model file exists."
data = keras.preprocessing.image.load_img(path, target_size=(img_height, img_width))
data = keras.preprocessing.image.img_to_array(data)
data = np.expand_dims(data, axis=0)
data = np.vstack([data])
result = np.argmax(predict_function(data))
return flower_dict[result]
四、交互设计
重点在于前面的模型设计,界面设计和启动函数 Main 就不用废话介绍了。
4.1 mainUI.py 和 main.ui 界面样式
直接用就好
main.ui 内容
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>MainWindow</class>
<widget class="QMainWindow" name="MainWindow">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>800</width>
<height>600</height>
</rect>
</property>
<property name="windowTitle">
<string>MainWindow</string>
</property>
<property name="layoutDirection">
<enum>Qt::LeftToRight</enum>
</property>
<property name="autoFillBackground">
<bool>true</bool>
</property>
<widget class="QWidget" name="centralwidget">
<widget class="QLabel" name="label">
<property name="geometry">
<rect>
<x>350</x>
<y>30</y>
<width>91</width>
<height>31</height>
</rect>
</property>
<property name="font">
<font>
<family>04b_21</family>
<pointsize>14</pointsize>
</font>
</property>
<property name="layoutDirection">
<enum>Qt::LeftToRight</enum>
</property>
<property name="text">
<string>花卉识别</string>
</property>
</widget>
<widget class="QPushButton" name="btn_validate">
<property name="geometry">
<rect>
<x>430</x>
<y>310</y>
<width>81</width>
<height>31</height>
</rect>
</property>
<property name="text">
<string>识别</string>
</property>
</widget>
<widget class="QPushButton" name="btn_train">
<property name="geometry">
<rect>
<x>300</x>
<y>310</y>
<width>81</width>
<height>31</height>
</rect>
</property>
<property name="text">
<string>训练</string>
</property>
</widget>
<widget class="QLabel" name="label_2">
<property name="geometry">
<rect>
<x>270</x>
<y>380</y>
<width>71</width>
<height>31</height>
</rect>
</property>
<property name="font">
<font>
<family>04b_21</family>
<pointsize>12</pointsize>
</font>
</property>
<property name="text">
<string>训练结果</string>
</property>
</widget>
<widget class="QLabel" name="label_3">
<property name="geometry">
<rect>
<x>270</x>
<y>440</y>
<width>71</width>
<height>31</height>
</rect>
</property>
<property name="font">
<font>
<family>04b_21</family>
<pointsize>12</pointsize>
</font>
</property>
<property name="text">
<string>识别结果</string>
</property>
</widget>
<widget class="QLabel" name="pic">
<property name="geometry">
<rect>
<x>260</x>
<y>80</y>
<width>261</width>
<height>191</height>
</rect>
</property>
<property name="autoFillBackground">
<bool>false</bool>
</property>
<property name="styleSheet">
<string notr="true">background-color: rgb(255, 255, 255);</string>
</property>
<property name="text">
<string/>
</property>
</widget>
<widget class="QTextEdit" name="trainResult">
<property name="geometry">
<rect>
<x>360</x>
<y>380</y>
<width>151</width>
<height>31</height>
</rect>
</property>
<property name="html">
<string><!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd">
<html><head><meta name="qrichtext" content="1" /><style type="text/css">
p, li { white-space: pre-wrap; }
</style></head><body style=" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;">
<p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:12pt;">未训练</span></p></body></html></string>
</property>
</widget>
<widget class="QTextEdit" name="validateResult">
<property name="geometry">
<rect>
<x>360</x>
<y>440</y>
<width>151</width>
<height>31</height>
</rect>
</property>
</widget>
</widget>
<widget class="QMenuBar" name="menubar">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>800</width>
<height>23</height>
</rect>
</property>
</widget>
<widget class="QStatusBar" name="statusbar"/>
</widget>
<resources/>
<connections>
<connection>
<sender>btn_train</sender>
<signal>clicked()</signal>
<receiver>MainWindow</receiver>
<slot>train()</slot>
<hints>
<hint type="sourcelabel">
<x>339</x>
<y>340</y>
</hint>
<hint type="destinationlabel">
<x>391</x>
<y>308</y>
</hint>
</hints>
</connection>
<connection>
<sender>btn_validate</sender>
<signal>clicked()</signal>
<receiver>MainWindow</receiver>
<slot>validate()</slot>
<hints>
<hint type="sourcelabel">
<x>475</x>
<y>338</y>
</hint>
<hint type="destinationlabel">
<x>559</x>
<y>312</y>
</hint>
</hints>
</connection>
</connections>
<slots>
<slot>train()</slot>
<slot>validate()</slot>
</slots>
</ui>
mainUI.py 文件的内容
# -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'main.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import QFont
class Ui_MainWindow(object):
def setupUi(self, MainWindow):
MainWindow.setObjectName("MainWindow")
MainWindow.resize(800, 600)
MainWindow.setLayoutDirection(QtCore.Qt.LeftToRight)
MainWindow.setAutoFillBackground(True)
self.centralwidget = QtWidgets.QWidget(MainWindow)
self.centralwidget.setObjectName("centralwidget")
self.label = QtWidgets.QLabel(self.centralwidget)
self.label.setGeometry(QtCore.QRect(350, 30, 91, 31))
font = QtGui.QFont()
font.setFamily("04b_21")
font.setPointSize(14)
self.label.setFont(font)
self.label.setLayoutDirection(QtCore.Qt.LeftToRight)
self.label.setObjectName("label")
self.btn_validate = QtWidgets.QPushButton(self.centralwidget)
self.btn_validate.setGeometry(QtCore.QRect(430, 310, 81, 31))
self.btn_validate.setObjectName("btn_validate")
self.btn_train = QtWidgets.QPushButton(self.centralwidget)
self.btn_train.setGeometry(QtCore.QRect(300, 310, 81, 31))
self.btn_train.setObjectName("btn_train")
self.label_2 = QtWidgets.QLabel(self.centralwidget)
self.label_2.setGeometry(QtCore.QRect(270, 380, 71, 31))
font = QtGui.QFont()
font.setFamily("04b_21")
font.setPointSize(12)
self.label_2.setFont(font)
self.label_2.setObjectName("label_2")
self.label_3 = QtWidgets.QLabel(self.centralwidget)
self.label_3.setGeometry(QtCore.QRect(270, 440, 71, 31))
font = QtGui.QFont()
font.setFamily("04b_21")
font.setPointSize(12)
self.label_3.setFont(font)
self.label_3.setObjectName("label_3")
self.pic = QtWidgets.QLabel(self.centralwidget)
self.pic.setGeometry(QtCore.QRect(260, 80, 261, 191))
self.pic.setAutoFillBackground(False)
self.pic.setStyleSheet("background-color: rgb(255, 255, 255);")
self.pic.setText("")
self.pic.setObjectName("pic")
self.trainResult = QtWidgets.QTextEdit(self.centralwidget)
self.trainResult.setGeometry(QtCore.QRect(360, 380, 151, 31))
self.trainResult.setObjectName("trainResult")
self.validateResult = QtWidgets.QTextEdit(self.centralwidget)
self.validateResult.setGeometry(QtCore.QRect(360, 440, 151, 31))
self.validateResult.setObjectName("validateResult")
self.validateResult.setFont(QFont('Arial', 12))
MainWindow.setCentralWidget(self.centralwidget)
self.menubar = QtWidgets.QMenuBar(MainWindow)
self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 23))
self.menubar.setObjectName("menubar")
MainWindow.setMenuBar(self.menubar)
self.statusbar = QtWidgets.QStatusBar(MainWindow)
self.statusbar.setObjectName("statusbar")
MainWindow.setStatusBar(self.statusbar)
self.retranslateUi(MainWindow)
self.btn_train.clicked.connect(MainWindow.train)
self.btn_validate.clicked.connect(MainWindow.validate)
QtCore.QMetaObject.connectSlotsByName(MainWindow)
def retranslateUi(self, MainWindow):
_translate = QtCore.QCoreApplication.translate
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
self.label.setText(_translate("MainWindow", "花卉识别"))
self.btn_validate.setText(_translate("MainWindow", "识别"))
self.btn_train.setText(_translate("MainWindow", "训练"))
self.label_2.setText(_translate("MainWindow", "训练结果"))
self.label_3.setText(_translate("MainWindow", "识别结果"))
self.trainResult.setHtml(_translate("MainWindow", "未训练"))
4.2 Main.py 主函数
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QFileDialog, QMainWindow
from mainUI import Ui_MainWindow
from Train import train
import validate
from validate import validate
class Main (QMainWindow,Ui_MainWindow):
def __init__(self):
super().__init__()
self.setupUi(self)
def train(self):
if( train() ):
self.trainResult.setText("训练完成")
validate.load_model_if_exists()
def validate(self):
filename, _ = QFileDialog.getOpenFileName(self, '打开图片')
print(filename)
img = QtGui.QPixmap(filename).scaled(self.pic.width(), self.pic.height())
self.pic.setPixmap(img)
self.validateResult.setText(validate(filename))
if __name__ == "__main__":
app = QtWidgets.QApplication(sys.argv)
window = Main()
window.show()
sys.exit(app.exec_())
将三、四步的代码导入项目,运行 Main.py,训练模型
五、自定义扩展
通过数据集的增加和模型的训练,使模型能够识别更多的花朵类型,在此提供扩展示例给大家参考,新的数据集可以在我资源中下载(for free),提供四种花朵类型。
5.1 添加数据集
下载新的花朵数据集,并添加到项目文件中
5.2 修改代码
对 train.py 中内容的修改如下,修改识别花朵的总类别数,新增四种,改为 9 种,改 labels
# 表示数据集中的类别数量为5类,可调整
num_classes = 9
labels = ['dasiy', 'dandelion', 'lotus','passiflora','pink','roses', 'sunflowers', 'tropaeolum','tulips']
对 validate.py 的内容做修改,使其能输出对应的花朵类型
flower_dict = {0:'dasiy', 1:'dandelion', 2:'lotus',3:'passiflora',4:'pink',5:'roses',
6:'sunflowers',7:'tropaeolum',8:'tulips'}
扩展后的模型混淆矩阵
总结
-
运行的时候会出现 warning ,但不影响正常运行。
-
虚拟机的配置要给高一些,我的是 4 个处理器(2核心),还是会出现 pycharm 闪退,或者是一到训练迭代就卡,直接退出程序,多运行几次就好了。
-
如果训练的模型准确率不高:可以试着通过以下手段进行模型调优
- 调整模型架构:尝试不同的模型架构,包括增加或减少层级、调整每层的神经元数量、改变激活函数等。有时候更复杂的模型能够更好地拟合数据,但也要注意避免过拟合。比如增加层级是比较简单的操作,只需要在 models.Sequential 中增加一层卷积层和池化层。
- 调整超参数:调整学习率、批量大小、优化器等超参数,以找到最佳的组合。通过网格搜索或随机搜索来寻找最优的超参数组合。
- 数据增强:对训练数据进行数据增强,如旋转、翻转、缩放等操作,以增加数据的多样性,帮助模型更好地泛化。
- 调整数据预处理:尝试不同的数据预处理方法,如标准化、归一化、特征缩放等,以确保数据的质量和一致性。