pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写

pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写

主要实现功能

第一次写博客,主要是想记录下最近踩的坑。最近想做一个集深度学习训练过程与缺陷检测过程为一体的界面,但是中间遇到许多问题,其中解决耗时最长的问题就是如何将深度学习训练过程实时显示在GUI界面的Textbrowser上,实现Textbrowser作为控制台输出的功能。

直接上代码

这里放的是GUI运行核心代码,其他的代码我将上传到CSDN下载中,有需要的小伙伴可以去下载,地址:https://download.csdn.net/download/weixin_42532587/12352345

import ctypes
import win32con
import sys
from PyQt5.QtWidgets import QMainWindow, QApplication, QDialog, QFileDialog, QMessageBox
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QThread, pyqtSignal
from mainwindow import Ui_MainWindow
from Model_training import Ui_Dialog
from Detection import Ui_Dialog1
import global_var as gl
from model_train import model_training
from model_prediction import predict, prediction
import pandas as pd

class EmittingStream(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)
    def write(self, text):
        self.textWritten.emit(str(text))
    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass


class MainUI(QMainWindow, Ui_MainWindow):
    def __init__(self):
        super(MainUI, self).__init__()
        self.setupUi(self)
        self.pushButton_Training.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Detection.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Tuichu.setStyleSheet('color:red')
        self.pushButton_Tuichu.clicked.connect(self.close)
        self.Exit.triggered.connect(self.close)
        # self.pushButton_Detection.clicked.connect()


class Training_Dialog(QDialog, Ui_Dialog):
    def __init__(self):
        super(Training_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_training.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_validation.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Start.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Stop.setStyleSheet('background:rgb(255, 0, 0)')
        self.comboBox_BS.addItems(['2', '4', '8', '16', '32', '64', '128'])
        self.comboBox_EP.addItems(['1', '20', '50', '100'])
        self.comboBox_LR.addItems(['0.1', '0.01', '0.001', '0.0001', '0.00001'])
        self.radioButton_AlexNet.setChecked(True)
        self.pushButton_Start.setEnabled(False)
        self.pushButton_training.clicked.connect(self.openfile)
        self.pushButton_validation.clicked.connect(self.openfile1)
        sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
        sys.stder = EmittingStream(textWritten=self.normalOutputWritten)
        self.pushButton_Start.clicked.connect(self.run_training)
        self.pushButton_Stop.clicked.connect(self.stop_training)
        self.my_thread = MyThread()  # 实例化线程对象

    def hyper_para(self):
        Epoch = self.comboBox_EP.currentText()
        gl.Epoch = int(Epoch)
        print('迭代次数为 %d' % gl.Epoch)
        batch_size = self.comboBox_BS.currentText()
        gl.batch_size = int(batch_size)
        print('批量尺寸为 %d' % gl.batch_size)
        Learning_rate = self.comboBox_LR.currentText()
        gl.learning_rate = float(Learning_rate)
        print('学习率为 %f' % gl.learning_rate)

    def stop_training(self):
        self.my_thread.is_on = False
        ret = ctypes.windll.kernel32.TerminateThread(  # @UndefinedVariable
            self.my_thread.handle, 0)
        print('终止训练', self.my_thread.handle, ret)

    def openfile(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i = directory
        if len(gl.gl_str_i1) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')

        print('成功加载训练文件', '训练文件夹所在位置:%s' % gl.gl_str_i)

    def openfile1(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i1 = directory
        if len(gl.gl_str_i1) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')
        else:
            self.pushButton_Start.setEnabled(True)
        print('成功加载验证文件', '验证文件夹所在位置:%s' % gl.gl_str_i1)

    def i_count(self):
        if self.radioButton_CarrotNet.text() == 'CarrotNet':
            if self.radioButton_CarrotNet.isChecked() == True:
                gl.gl_int_i = 2
                print('model is CarrotNet')

            elif self.radioButton_AlexNet.text() == 'AlexNet':
                if self.radioButton_AlexNet.isChecked() == True:
                    gl.gl_int_i = 1
                    print('model is AlexNet')

    def run_training(self):
        self.pushButton_Start.setEnabled(False)
        self.textBrowser.clear()
        self.i_count()
        self.hyper_para()
        if gl.gl_str_i == 'one':
            QMessageBox.critical(self, '错误', '请加载训练图片')
            self.my_thread.is_on = False
        elif gl.gl_str_i1 == 'one':
            QMessageBox.critical(self, '错误', '请加载验证图片')
            self.my_thread.is_on = False
        else:
            self.my_thread.is_on = True
        self.my_thread.start()  # 启动线程
        self.pushButton_Start.setEnabled(True)

    def normalOutputWritten(self, text):
        """Append text to the QTextEdit."""
        # Maybe QTextEdit.append() works as well, but this is how I do it:
        cursor = self.textBrowser.textCursor()
        cursor.movePosition(QtGui.QTextCursor.End)
        cursor.insertText(text)
        self.textBrowser.setTextCursor(cursor)
        self.textBrowser.ensureCursorVisible()


class MyThread(QThread):  # 线程类
    # my_signal = pyqtSignal(str)  # 自定义信号对象。参数str就代表这个信号可以传一个字符串

    def __init__(self):
        super(MyThread, self).__init__()
        # self.count = 0
        self.is_on = True

    def run(self):  # 线程执行函数
        self.handle = ctypes.windll.kernel32.OpenThread(  # @UndefinedVariable
            win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId()))
        while self.is_on:
            model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1, gl.Epoch,
                           gl.batch_size, gl.learning_rate)
            self.is_on = False


class EmittingStream1(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        self.textWritten.emit(str(text))

    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass


class Detection_Dialog(QDialog, Ui_Dialog1):
    def __init__(self):
        super(Detection_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_start_detection.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_model.setStyleSheet('background:rgb(255, 0, 0)')
        self.pushButton_picture.setStyleSheet('background:rgb(255, 0, 0)')
        self.radioButton.setChecked(True)
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(False)
        self.pushButton_start_detection.setEnabled(False)
        self.pushButton_exit.setEnabled(False)
        self.pushButton_save.setEnabled(False)

        sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)
        sys.stder = EmittingStream1(textWritten=self.normalOutputWritten1)
        # print('请先选择逐批检测还是逐个检测')
        self.pushButton_model.clicked.connect(self.message)
        self.pushButton_model.clicked.connect(self.load_moad)
        self.pushButton_picture.clicked.connect(self.load_image)
        self.my_thread1 = My_Thread1()  # 实例化线程对象
        self.pushButton_start_detection.clicked.connect(self.detection)
        self. pushButton_save.clicked.connect(self.save_result)
        self.pushButton_exit.clicked.connect(self.close)

    def save_result(self):
        path = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        data = pd.DataFrame(gl.Y)
        data.to_csv(path + '/' + 'detection_result.csv', index=True)


    def message(self):
        QMessageBox.question(self, '提示', '请先选择逐批检测还是逐个检测')
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(True)
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_exit.setEnabled(True)

    def load_image(self):
        if self.radioButton.text() == '逐批检测':
            if self.radioButton.isChecked() == True:
                directory1 = QFileDialog.getExistingDirectory(self, "请选择文件路径")
                gl.gl_str_i3 = directory1
                print('成功导入检测文件', '检测文件所在位置:%s' % gl.gl_str_i3)
            elif self.radioButton_2.text() == '逐个检测':
                if self.radioButton_2.isChecked() == True:
                    fname, _ = QFileDialog.getOpenFileName(self, '选择图片', 'c:\\', 'Image files(*.jpg *.gif *.png)')
                    gl.gl_str_i4 = fname
                    print('成功导入检测图片', '检测文件所在位置:%s' % gl.gl_str_i4)
                else:
                    print('请正确选择检测文件路径')


    def load_moad(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        gl.gl_str_i2 = directory
        print('成功加载模型', '模型所在位置:%s' % gl.gl_str_i2)

    def normalOutputWritten1(self, text):
        """Append text to the QTextEdit."""
        # Maybe QTextEdit.append() works as well, but this is how I do it:
        cursor1 = self.textBrowser1.textCursor()
        cursor1.movePosition(QtGui.QTextCursor.End)
        cursor1.insertText(text)
        self.textBrowser1.setTextCursor(cursor1)
        self.textBrowser1.ensureCursorVisible()

    def detection(self):
        self.pushButton_start_detection.setEnabled(False)
        if self.radioButton.text() == '逐批检测':
            if self.radioButton.isChecked() == True:
                gl.i = 0
            elif self.radioButton_2.text() == '逐个检测':
                if self.radioButton_2.isChecked() == True:
                    gl.i = 1

        self.my_thread1.start()  # 启动线程
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_save.setEnabled(True)


class My_Thread1(QThread):
    def __init__(self):
        super(My_Thread1, self).__init__()

    def run(self):  # 线程执行函数
        print('测试开始')
        if gl.i == 0:
            prediction(gl.gl_str_i2, gl.gl_str_i3)
        else:
            predict(gl.gl_str_i2, gl.gl_str_i4)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    main = MainUI()
    Training = Training_Dialog()
    Detection = Detection_Dialog()
    main.pushButton_Training.clicked.connect(Training.show)
    main.pushButton_Detection.clicked.connect(Detection.show)
    main.pushButton_Tuichu.clicked.connect(Training.close)
    main.pushButton_Tuichu.clicked.connect(Detection.close)
    main.Exit.triggered.connect(Training.close)
    main.Exit.triggered.connect(Detection.close)
    main.show()
    sys.exit(app.exec_())

重要代码

这里是关于如何将深度学习训练过程实时显示到GUI的Textbrowser上

class EmittingStream(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        self.textWritten.emit(str(text))

    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass

一定要加上flush函数的定义,之前在CSDN上找了很久,都没有这行,导致GUI界面上的Textbrowsers只能输出深度学习训练过程的第一行,不能实现实时刷新的功能,加上这个定义就可以完美解决

sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
sys.stder = EmittingStream(textWritten=self.normalOutputWritten)
class MyThread(QThread):  # 线程类
    # my_signal = pyqtSignal(str)  # 自定义信号对象。参数str就代表这个信号可以传一个字符串

    def __init__(self):
        super(MyThread, self).__init__()
        # self.count = 0
        self.is_on = True

    def run(self):  # 线程执行函数
        self.handle = ctypes.windll.kernel32.OpenThread(  # @UndefinedVariable
            win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
        while self.is_on:
            model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1, gl.Epoch,
                           gl.batch_size, gl.learning_rate)
            self.is_on = False
def stop_training(self):
        self.my_thread.is_on = False
        ret = ctypes.windll.kernel32.TerminateThread(  # @UndefinedVariable
            self.my_thread.handle, 0)
        print('终止训练', self.my_thread.handle, ret)

self.handle = ctypes.windll.kernel32.OpenThread( # @UndefinedVariable
win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
def stop_training(self): 终止训练过程

出错解决

用前面的代码理论上是可以实现实时显示深度学习训练过程的,但我在刚开始使用时,总会出现== finished with exit code -1073740791 (0xC0000409)==,当把 sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)
sys.stder = EmittingStream1(textWritten=self.normalOutputWritten1)这两行注释掉时程序可以正常运行,只不过内容没有输出到textbrowser上。在网上搜了一大圈,也没有发现适合我程序的,最后才发现是keras的版本和Tensorflow的版本不匹配造成的,但是之前不在GUI内运行不报错,在GUI框架下运行就会报错,==最终选择Keras2.2.5,tensorflow1.14.0 ==解决了问题,但是运行程序时会出现一大串警告,不过不影响最终结果

最终效果

在这里插入图片描述

未解决问题

我这个有两个子界面,每个子界面都有一个Textbrowser,而且都想达到实时刷新的效果,但是当同时使用时会出现两个Textbrowser内容相互干扰的现象。哪位大神知道如何玩解决的话,还望不吝赐教

补充说明

本界面还使用了全局变量实现不同函数之间的互相传值,具体方法是先建个global_var.py文件,将需要传值的参数预先定义。此后各个文件import使用就行了

# coding=utf-8
# 在别的文件使用方法:
# import global_var_model as gl
#  gl.gl_int_i += 4,可以通过访问和修改gl.gl_int_i来实现python的全局变量,或者叫静态变量访问
# gl.gl_int_i
import numpy as np
gl_int_i = 1  # 这里的gl_int_i是最常用的用于标记的全局变量
gl_str_i = 'one'
gl_str_i1 = 'one'
gl_str_i2 = 'one'
gl_str_i3 = 'one'
gl_str_i4 = 'one'
batch_size = 1
Epoch = 1
learning_rate = 0.1
i = 0
Y = np.array([])

写在最后

第一次写博客,语言也不怎么精炼,文学功底不行,希望大家将就着看,整个GUI的全部代码我将在后续上传到CSDN上。当然这篇博客也借鉴了很多前人的经验,在此表示感谢

  • 26
    点赞
  • 165
    收藏
    觉得还不错? 一键收藏
  • 52
    评论
### 回答1: 手写数字识别是深度学习在计算机视觉领域的一项经典任务,可以使用PyQt5TensorFlow Keras框架来实现。这种任务可以通过卷积神经网络(CNN)来完成。 首先,需要下载一个手写数字图像数据集,例如MNIST数据集。然后,可以使用TensorFlow Keras框架来搭建一个简单的CNN模型,来对图像进行分类。这个CNN模型可以包含一些卷积层、池化层、扁平层和全连接层来实现对手写数字图像的分类。 接下来,使用PyQt5编写一个简单的GUI界面,提供用户手动输入数字图像的功能。GUI界面可以提供一个画布来让用户手动在上面绘制数字,然后对这个数字图像进行预测和分类。 具体实现时,可以结合PyQt5的信号和槽机制,将用户手动绘制的数字图像与CNN模型进行关联。当用户完成数字图像的绘制后,程序可以自动进行图像分类,并输出数字的识别结果。 总之,PyQt5TensorFlow Keras框架提供了一个完整的工具链,用于实现手写数字识别的任务。开发者可以使用这些工具和技术来实现更加复杂的图像识别和分析任务。 ### 回答2: 手写数字识别是深度学习中的一个常见问题,而PyQt5则是一个流行的Python图形界面开发框架,可以将模型的结果以可视化的方式展示给用户。因此,使用PyQt5TensorFlow-Keras搭建一个手写数字识别的应用程序是很有实际应用价值的。下面简单介绍一下实现步骤。 首先,我们需要一个手写数字数据集,可以使用MNIST数据集。通过使用TensorFlow-Keras的API,我们可以快速地构建一个CNN模型,并在训练数据上进行训练。 接下来,我们需要使用PyQt5构建GUI界面,这里可以使用QWidget框架。我们需要构建一个画布,允许用户手写数字,然后将用户手写的图像输入到CNN模型中进行预测。 在这里,我们可以使用QPainter来绘图,它可以使用户绘制完整的数字。在预测数字时,我们需要对图像进行一些预处理,例如将其大小调整为网络需要的输入尺寸,并将其转换为灰度图像。 在模型训练完毕之后,我们可以将模型保存下来,然后在PyQt5应用程序中加载模型,并使用它进行手写数字的识别。当用户在画布上完成手写数字绘制后,我们可以将其送入已经训练好的CNN模型,然后让程序显示识别结果。 通过这样的方式,我们可以使用PyQt5TensorFlow-Keras开发手写数字识别应用程序,为用户提供更加便捷的数字识别方式。 ### 回答3: 手写数字识别是深度学习中的一个经典问题,利用人工神经网络或深度卷积神经网络可以达到很高的准确率。PyQt5是一个Python编写GUI库,可以将深度学习算法应用到用户友好的界面中,同时TensorFlow-Keras是一个强大的深度学习框架,利用它可以快速搭建一个卷积神经网络。 首先,我们需要准备手写数字数据集,比如MNIST数据集。我们可以使用Keras自带的数据集接口进行加载。然后,通过PyQt5绘制一个界面,使得用户可以在界面上进行手写数字输入。手写数字数据可以通过鼠标或触控板进行输入,我们可以将手写数字截图并进行处理,可以使用 PIL 库或 OpenCV 进行图片处理,将图片大小调整为合适的大小。接着,我们需要将图片输入到卷积神经网络中进行预测。我们可以使用TensorFlow-Keras搭建一个卷积神经网络模型,并把刚刚处理好的图片输入到模型中,进行预测。最后,我们可以在界面上输出预测结果,告诉用户识别的数字是什么。 总之,借助PyQt5TensorFlow-Keras的强大功能,我们可以轻松地设计一个手写数字识别的应用程序。但是需要注意的是,要精度高的数字识别需要使用比较深的卷积神经网络模型,并花费更多的时间来训练和调优模型。
评论 52
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值