YOLOv5 训练、测试与模型格式转换的 PyQt5 界面

完整代码如下:
```python
# -*- coding: utf-8 -*-
import os
import subprocess
from datetime import datetime

import yaml
# Form implementation generated from reading ui file 'single-train.ui'
#
# Created by: PyQt5 UI code generator 5.15.9
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again.  Do not edit this file unless you know what you are doing.


from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import QProcess

from base_utils.rename_class import LabelRenamer
from base_utils.split_dataset import DataSplitter
from base_utils.xmlTotxt import VOCtoYOLOConverter


class Ui_MainWindow(object):
    """Main window for a YOLOv5 train / test / export front-end.

    ``setupUi`` builds the widgets (pyuic5-generated layout). The remaining
    methods run ``train.py``, ``detect.py`` and ``export.py`` through a single
    ``QProcess`` and stream its merged stdout/stderr into the log view.
    """

    def setupUi(self, MainWindow):
        """Create all widgets on *MainWindow* and the shared QProcess."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1032, 648)
        MainWindow.setStyleSheet("background-image: url(bg.jpg);")
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setGeometry(QtCore.QRect(230, 0, 121, 41))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(22)
        self.label.setFont(font)
        self.label.setObjectName("label")
        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton.setGeometry(QtCore.QRect(24, 92, 81, 31))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(12)
        self.pushButton.setFont(font)
        self.pushButton.setObjectName("pushButton")
        self.pushButton.clicked.connect(self.train)
        self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_2.setGeometry(QtCore.QRect(24, 202, 81, 31))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(12)
        self.pushButton_2.setFont(font)
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_2.clicked.connect(self.test)
        self.label_3 = QtWidgets.QLabel(self.centralwidget)
        self.label_3.setGeometry(QtCore.QRect(130, 180, 71, 31))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(12)
        self.label_3.setFont(font)
        self.label_3.setObjectName("label_3")
        self.model_path = QtWidgets.QTextEdit(self.centralwidget)
        self.model_path.setGeometry(QtCore.QRect(220, 180, 311, 31))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(11)
        self.model_path.setFont(font)
        self.model_path.setObjectName("model_path")
        self.label_4 = QtWidgets.QLabel(self.centralwidget)
        self.label_4.setGeometry(QtCore.QRect(130, 230, 71, 31))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(12)
        self.label_4.setFont(font)
        self.label_4.setObjectName("label_4")
        self.test_image = QtWidgets.QTextEdit(self.centralwidget)
        self.test_image.setGeometry(QtCore.QRect(220, 230, 311, 31))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(11)
        self.test_image.setFont(font)
        self.test_image.setObjectName("test_image")
        self.label_5 = QtWidgets.QLabel(self.centralwidget)
        self.label_5.setGeometry(QtCore.QRect(20, 320, 81, 31))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(12)
        self.label_5.setFont(font)
        self.label_5.setObjectName("label_5")
        self.log_info = QtWidgets.QTextBrowser(self.centralwidget)
        self.log_info.setGeometry(QtCore.QRect(10, 360, 1011, 241))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(11)
        self.log_info.setFont(font)
        self.log_info.setObjectName("log_info")
        self.label_6 = QtWidgets.QLabel(self.centralwidget)
        self.label_6.setGeometry(QtCore.QRect(690, 0, 121, 41))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(22)
        self.label_6.setFont(font)
        self.label_6.setObjectName("label_6")
        self.groupBox = QtWidgets.QGroupBox(self.centralwidget)
        self.groupBox.setGeometry(QtCore.QRect(10, 50, 541, 271))
        self.groupBox.setTitle("")
        self.groupBox.setObjectName("groupBox")
        self.label_7 = QtWidgets.QLabel(self.groupBox)
        self.label_7.setGeometry(QtCore.QRect(120, 70, 71, 31))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(12)
        self.label_7.setFont(font)
        self.label_7.setObjectName("label_7")
        self.input_root = QtWidgets.QTextEdit(self.groupBox)
        self.input_root.setGeometry(QtCore.QRect(210, 70, 311, 31))
        self.input_root.setObjectName("input_root")
        self.camer_num = QtWidgets.QTextEdit(self.groupBox)
        self.camer_num.setGeometry(QtCore.QRect(210, 20, 311, 31))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(11)
        self.camer_num.setFont(font)
        self.camer_num.setObjectName("camer_num")
        self.label_2 = QtWidgets.QLabel(self.groupBox)
        self.label_2.setGeometry(QtCore.QRect(120, 20, 71, 31))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(12)
        self.label_2.setFont(font)
        self.label_2.setObjectName("label_2")
        self.pushButton_3 = QtWidgets.QPushButton(self.groupBox)
        self.pushButton_3.setGeometry(QtCore.QRect(10, 230, 81, 31))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(12)
        self.pushButton_3.setFont(font)
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_3.clicked.connect(self.pt2onnx)
        self.label_8 = QtWidgets.QLabel(self.groupBox)
        self.label_8.setGeometry(QtCore.QRect(120, 230, 71, 31))
        font = QtGui.QFont()
        font.setFamily("微软雅黑")
        font.setPointSize(12)
        self.label_8.setFont(font)
        self.label_8.setObjectName("label_8")
        self.any_model_path = QtWidgets.QTextEdit(self.groupBox)
        self.any_model_path.setGeometry(QtCore.QRect(210, 230, 311, 31))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(11)
        # The font belongs to this field (it was previously applied to camer_num by mistake).
        self.any_model_path.setFont(font)
        self.any_model_path.setObjectName("any_model_path")
        self.test_result = QtWidgets.QGroupBox(self.centralwidget)
        self.test_result.setGeometry(QtCore.QRect(560, 50, 461, 271))
        self.test_result.setTitle("")
        self.test_result.setObjectName("test_result")
        self.groupBox.raise_()
        self.label.raise_()
        self.pushButton.raise_()
        self.pushButton_2.raise_()
        self.label_3.raise_()
        self.model_path.raise_()
        self.label_4.raise_()
        self.test_image.raise_()
        self.label_5.raise_()
        self.log_info.raise_()
        self.label_6.raise_()
        self.test_result.raise_()
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1032, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

        # One process object is shared by train / test / export.
        self.process = QProcess()
        # Merge stderr into stdout so Loging receives everything the scripts print.
        self.process.setProcessChannelMode(QProcess.MergedChannels)

    def retranslateUi(self, MainWindow):
        """Set the window title and all user-visible widget texts."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "YOLOV5"))
        self.label.setText(_translate("MainWindow", "<html><head/><body><p>目标检测</p></body></html>"))
        self.pushButton.setText(_translate("MainWindow", "训练"))
        self.pushButton_2.setText(_translate("MainWindow", "测试"))
        self.label_3.setText(_translate("MainWindow", "模型路径:"))
        self.label_4.setText(_translate("MainWindow", "图像路径:"))
        self.label_5.setText(_translate("MainWindow", "状态"))
        self.label_6.setText(_translate("MainWindow", "测试结果"))
        self.label_7.setText(_translate("MainWindow", "输入路径:"))
        self.label_2.setText(_translate("MainWindow", "相机编号"))
        self.pushButton_3.setText(_translate("MainWindow", "格式转换"))
        self.label_8.setText(_translate("MainWindow", "模型路径:"))

    @staticmethod
    def _dataset_dirs(rootdir):
        """Return the ``(images, labels)`` directory paths inside *rootdir*."""
        return os.path.join(rootdir, 'images'), os.path.join(rootdir, 'labels')

    @staticmethod
    def _text(widget):
        """Return a QTextEdit's plain text with surrounding whitespace removed.

        Stray spaces or a trailing newline in the box would otherwise end up
        inside file paths and command-line arguments.
        """
        return widget.toPlainText().strip()

    def _is_busy(self):
        """Return True (and note it in the log) if a script is still running."""
        if self.process.state() == QProcess.NotRunning:
            return False
        self.log_info.append('已有任务正在运行,请等待完成')
        return True

    def _start_process(self, arguments, on_started, on_finished):
        """Run ``python <arguments...>`` in ``self.process`` with fresh signal wiring.

        Connections left over from a previous run are removed first. Otherwise
        handlers would accumulate across clicks: output would be read once per
        stacked ``Loging`` connection, and status slots of earlier tasks would
        keep firing.

        Args:
            arguments: Script name followed by its arguments, one item per argv entry.
            on_started: Slot invoked when the process has started.
            on_finished: Slot invoked when the process finishes.

        Returns:
            True if the process was launched, False if another one is running.
        """
        if self._is_busy():
            return False
        for signal in (self.process.started,
                       self.process.readyReadStandardOutput,
                       self.process.finished):
            try:
                signal.disconnect()
            except TypeError:
                # PyQt raises TypeError when the signal has no connections yet.
                pass
        self.process.started.connect(on_started)
        self.process.readyReadStandardOutput.connect(self.Loging)
        self.process.finished.connect(on_finished)
        # Pass an argument list rather than one command string, so paths that
        # contain spaces reach the script intact. Script paths are relative to
        # the current working directory (presumably the YOLOv5 repo root).
        self.process.start('python', list(arguments))
        return True

    def preprocess(self):
        """Convert XML labels to txt, check class ids and split the dataset.

        The dataset root is read from the input-path box. ``images`` and
        ``labels`` sub-directories are used inside it.

        Returns:
            True when preprocessing ran, False when no root path was entered.
        """
        rootdir = self._text(self.input_root)
        if not rootdir:
            self.log_info.append('请先填写输入路径')
            return False
        img_path, lab_path = self._dataset_dirs(rootdir)
        print('xml转txt')
        converter = VOCtoYOLOConverter(lab_path)
        converter.convert_all()
        print('检查类别')
        # NOTE(review): presumably rewrites every class id to 0, matching the
        # single-class ("defect") dataset YAML written by train() — confirm.
        renamer = LabelRenamer(lab_path, 0)
        renamer.rename_class()
        print("划分数据")
        splitter = DataSplitter(img_path, lab_path, rootdir)
        splitter.split_data(train_rate=0.85, val_rate=0.15, test_rate=0, need_unlabels=True)
        print('处理完成')
        return True

    def train(self):
        """Preprocess the dataset, write ``data/base<camera>.yaml`` and run ``train.py``."""
        rootdir = self._text(self.input_root)
        camera_num = self._text(self.camer_num)
        if not rootdir or not camera_num:
            self.log_info.append('请先填写相机编号和输入路径')
            return
        # Check before preprocessing: re-splitting the dataset under a running
        # job would corrupt it, even though QProcess would refuse the start.
        if self._is_busy():
            return
        try:
            if not self.preprocess():
                return
        except Exception as exc:
            # This is a Qt slot: an uncaught exception here aborts the whole
            # application under PyQt5 >= 5.5, so report it in the log instead.
            self.log_info.append(f'数据预处理失败: {exc}')
            return

        # The dataset YAML lives in a "data" directory next to this script.
        script_directory = os.path.dirname(os.path.abspath(__file__))
        target_directory = os.path.join(script_directory, 'data')
        os.makedirs(target_directory, exist_ok=True)
        yaml_file_path = os.path.join(target_directory, f"base{camera_num}.yaml")
        yaml_data = {
            'path': rootdir,
            'train': 'train/images',
            'val': 'valid/images',
            'nc': 1,
            'names': ['defect'],
            'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(yaml_file_path, 'w', encoding='utf-8') as file:
            yaml.dump(yaml_data, file, default_flow_style=False)
        print(f"{yaml_file_path} 创建完成")
        self._start_process(
            ['train.py',
             '--data', yaml_file_path,
             '--name', camera_num,
             '--project', 'runs/train'],
            self.StartTrain, self.FinishTrain)

    def test(self):
        """Run ``detect.py`` with the given weights on the given image source."""
        test_path = self._text(self.test_image)
        model_path = self._text(self.model_path)
        camera_num = self._text(self.camer_num)
        if not test_path or not model_path:
            self.log_info.append('请先填写模型路径和图像路径')
            return
        self._start_process(
            ['detect.py',
             '--weights', model_path,
             '--source', test_path,
             '--name', f'base{camera_num}',
             '--project', 'runs/detect'],
            self.StartTest, self.FinishTest)

    def pt2onnx(self):
        """Run ``export.py`` on the weights in the conversion model-path box.

        Only ``--weights`` is passed, so the output formats are whatever
        export.py produces by default.
        """
        model_path = self._text(self.any_model_path)
        if not model_path:
            self.log_info.append('请先填写模型路径')
            return
        self._start_process(['export.py', '--weights', model_path],
                            self.StartConvert, self.FinishConvert)

    def Loging(self):
        """Append the process output that is currently available to the log view."""
        # errors='replace': a multi-byte character split across two reads, or
        # output in a non-UTF-8 console encoding, must not raise inside a slot.
        data = self.process.readAllStandardOutput().data().decode('utf-8', errors='replace')
        # append() already starts a new paragraph; dropping the trailing line
        # break avoids an empty line after every chunk.
        text = data.rstrip('\r\n')
        if text:
            self.log_info.append(text)

    def StartTrain(self):
        self.label_5.setText('训练开始')

    def FinishTrain(self):
        self.label_5.setText('训练完成')

    def StartTest(self):
        self.label_5.setText('测试开始')

    def FinishTest(self):
        self.label_5.setText('测试完成')

    def StartConvert(self):
        self.label_5.setText('转换开始')

    def FinishConvert(self):
        self.label_5.setText('转换完成')


if __name__ == "__main__":
    import sys

    # Build the Qt application, attach the generated UI to a main window,
    # then hand control to the event loop until the window is closed.
    qt_app = QtWidgets.QApplication(sys.argv)
    main_window = QtWidgets.QMainWindow()
    window_ui = Ui_MainWindow()
    window_ui.setupUi(main_window)
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
```

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值