用 PyQt5 给深度学习目标检测+跟踪(YOLOv3 + SiamRPN)搭建界面(2)

在上一篇的基础上重新布局,并加入了许多内容:除了上篇文章提到的帧率显示和目标检测,还加入了目标跟踪以及与服务端的通信连接。

最终效果图:
[效果图]

下面的程序是由 Qt Designer 生成的 UI 程序(test3.py):

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'test3.ui'
#
# Created by: PyQt5 UI code generator 5.9.2
#
# WARNING! All changes made in this file will be lost!
import sys
from PyQt5.QtWidgets import *
from PyQt5 import QtCore, QtGui, QtWidgets


class EmittingStream(QtCore.QObject):
    """File-like object that forwards everything written to it as a Qt signal.

    Instances are assigned to ``sys.stdout`` / ``sys.stderr`` so that
    ``print`` output can be shown in a widget; connect ``textWritten`` to the
    slot that displays the text.
    """

    textWritten = QtCore.pyqtSignal(str)  # emitted with each chunk of written text

    def write(self, text):
        """Emit *text* (converted to ``str``) through ``textWritten``.

        Returns the number of characters written, like ``io.TextIOBase.write``.
        """
        text = str(text)
        self.textWritten.emit(text)
        return len(text)

    def flush(self):
        """Do nothing: text is emitted immediately, so nothing is buffered.

        Must exist because this object replaces ``sys.stdout``/``sys.stderr``:
        ``print(..., flush=True)`` and interpreter shutdown call ``flush()``
        on those streams and would otherwise raise ``AttributeError``.
        """


class Ui_MainWindow(object):
    """Widget layout of the detection/tracking main window.

    ``setupUi`` builds four panels on a 953x954 window: the raw camera view
    (top-left), the tracking view (top-right), the detection view
    (bottom-left) and a control panel (bottom-right) holding the button grid
    and a read-only text box that receives ``sys.stdout``/``sys.stderr``.

    NOTE: this module is generated from test3.ui; mirror the object-name
    fixes below in the .ui file, or regenerating it will undo them.
    """

    def setupUi(self, MainWindow):
        """Create all widgets as children of *MainWindow* and install them.

        Side effect: replaces ``sys.stdout`` and ``sys.stderr`` with
        ``EmittingStream`` objects whose text is delivered to
        ``self.outputWritten``.
        """
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(953, 954)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")

        # Panels are created in a fixed order: sibling widgets created later
        # are stacked above earlier ones.
        self.cam_frame, self.cam_label = self._make_panel(
            "cam_frame", QtCore.QRect(10, 10, 461, 441),
            "cam_label", QtCore.QRect(10, 10, 441, 421))
        self.detect_frame, self.detect_label = self._make_panel(
            "detect_frame", QtCore.QRect(10, 460, 461, 451),
            "detect_label", QtCore.QRect(10, 10, 441, 431))
        self.bn_frame, self.text_label = self._make_panel(
            "bn_frame", QtCore.QRect(480, 460, 461, 451),
            "text_label", QtCore.QRect(10, 10, 441, 431))

        self.textEdit = QtWidgets.QTextEdit(self.bn_frame)
        self.textEdit.setGeometry(QtCore.QRect(10, 90, 450, 350))
        self.textEdit.setObjectName("textEdit")
        self.textEdit.setReadOnly(True)
        # Redirect print() and traceback output into textEdit.
        sys.stdout = EmittingStream(textWritten=self.outputWritten)
        sys.stderr = EmittingStream(textWritten=self.outputWritten)

        self.track_frame, self.track_label = self._make_panel(
            "track_frame", QtCore.QRect(480, 10, 461, 441),
            "track_label", QtCore.QRect(10, 10, 441, 431))

        # Two-row button grid at the top of the control panel, above textEdit.
        self.widget = QtWidgets.QWidget(self.bn_frame)
        self.widget.setGeometry(QtCore.QRect(10, 20, 382, 62))
        self.widget.setObjectName("widget")
        self.gridLayout = QtWidgets.QGridLayout(self.widget)
        self.gridLayout.setContentsMargins(0, 0, 0, 0)
        self.gridLayout.setSpacing(20)
        self.gridLayout.setObjectName("gridLayout")

        # Object names must match the attribute names exactly (no trailing
        # spaces) so QMetaObject.connectSlotsByName can find them.
        self.cam_bn = self._add_button("cam_bn", 0, 0)
        self.load_detect_model_bn = self._add_button("load_detect_model_bn", 0, 1)
        self.load_cfg_bn = self._add_button("load_cfg_bn", 0, 2)
        self.detect_bn = self._add_button("detect_bn", 0, 3)
        self.load_track_model_bn = self._add_button("load_track_model_bn", 1, 0)
        self.track_bn = self._add_button("track_bn", 1, 1)
        self.connet_server_bn = self._add_button("connet_server_bn", 1, 2)
        self.exit_bn = self._add_button("exit_bn", 1, 3)

        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 953, 17))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def _make_panel(self, frame_name, frame_rect, label_name, label_rect):
        """Create a styled QFrame on the central widget with a QLabel inside it."""
        frame = QtWidgets.QFrame(self.centralwidget)
        frame.setGeometry(frame_rect)
        frame.setFrameShape(QtWidgets.QFrame.StyledPanel)
        frame.setFrameShadow(QtWidgets.QFrame.Raised)
        frame.setObjectName(frame_name)
        label = QtWidgets.QLabel(frame)
        label.setGeometry(label_rect)
        label.setObjectName(label_name)
        return frame, label

    def _add_button(self, name, row, column):
        """Create a push button named *name* in the button grid at (row, column)."""
        button = QtWidgets.QPushButton(self.widget)
        button.setObjectName(name)
        self.gridLayout.addWidget(button, row, column, 1, 1)
        return button

    def outputWritten(self, text):
        """Append *text* to ``textEdit`` and keep the end of it visible.

        Default target of the stdout/stderr redirection installed by
        ``setupUi``; subclasses may override it.
        """
        cursor = self.textEdit.textCursor()
        cursor.movePosition(QtGui.QTextCursor.End)
        cursor.insertText(text)
        self.textEdit.setTextCursor(cursor)
        self.textEdit.ensureCursorVisible()

    def retranslateUi(self, MainWindow):
        """Set the window title and all user-visible label/button texts."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.cam_label.setText(_translate("MainWindow", "原始画面"))
        self.detect_label.setText(_translate("MainWindow", "检测画面"))
        self.track_label.setText(_translate("MainWindow", "跟踪画面"))
        self.text_label.setText(_translate("MainWindow", "文本显示"))
        self.cam_bn.setText(_translate("MainWindow", "打开摄像头"))
        self.load_detect_model_bn.setText(_translate("MainWindow", "加载检测模型"))
        self.load_cfg_bn.setText(_translate("MainWindow", "加载cfg文件"))
        self.detect_bn.setText(_translate("MainWindow", "开始检测"))
        self.load_track_model_bn.setText(_translate("MainWindow", "加载跟踪模型"))
        self.track_bn.setText(_translate("MainWindow", "开始跟踪"))
        self.connet_server_bn.setText(_translate("MainWindow", "连接无人机"))
        self.exit_bn.setText(_translate("MainWindow", "退出"))


if __name__ == '__main__':
    # Standalone preview of the generated layout on a bare QMainWindow.
    qt_app = QApplication(sys.argv)
    window = QMainWindow()
    form = Ui_MainWindow()
    form.setupUi(window)  # populate the main window with all widgets
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

下面是应用的主程序:

from project_demo.test.test3 import Ui_MainWindow
from PyQt5.QtCore import QTimer, QCoreApplication
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
from pyqt5.yolov3.util import *
from pyqt5.yolov3.cam_demo import arg_parse, get_test_input, write
from pyqt5.yolov3.darknet import Darknet
from pyqt5.yolov3.preprocess import prep_image
import time
import cv2
import torch
import socket
import threading
# from project_demo.test import so_client
from pyqt5.pysot.pysot.core.config import cfg
from pyqt5.pysot.pysot.models.model_builder import ModelBuilder
from pyqt5.pysot.pysot.tracker.tracker_builder import build_tracker


class CamShow(QMainWindow, Ui_MainWindow):
    """Main window: camera preview, YOLOv3 detection, SiamRPN tracking, TCP client.

    Widgets come from ``Ui_MainWindow.setupUi``. This class wires the buttons
    to their handlers and drives three QTimer loops: camera grab
    (``TimerOutFun``), detection (``object_detection``) and tracking
    (``object_tracking_test``).
    """

    # pysot config used to build the tracker.
    TRACK_CONFIG = '/home/ljw/桌面/demo/project_demo/pyqt5/pysot/experiments/siamrpn_alex_dwxcorr/config.yaml'
    # Hard-coded initial tracking box. Presumably [x, y, w, h], the same layout
    # as the bbox drawn in object_tracking_test -- TODO: take it from a detection.
    INIT_RECT = (10, 30, 200, 200)
    # Server contacted by tcp_client_concurrency.
    SERVER_HOST = "10.23.21.69"
    SERVER_PORT = 5033
    # Interval (ms) of the detection and tracking timers.
    PROCESS_INTERVAL_MS = 3

    def __del__(self):
        # Best-effort camera release when the object is destroyed. ``camera``
        # may be missing if __init__ failed early, and a finalizer must never
        # raise, so any error is ignored.
        try:
            self.camera.release()
        except Exception:
            return

    def __init__(self, parent=None):
        super(CamShow, self).__init__(parent)
        self.setupUi(self)
        # Open the camera and disable every button that only works once the
        # camera is running; StartCamera enables them.
        self.PrepWidgets()
        # Connect each button's clicked signal to its handler.
        self.CallBackFunctions()
        # Initialise the state shared by the handlers and timer slots.
        self.PrepParameters()
        # Camera-grab timer: each timeout calls TimerOutFun, which reads and
        # displays one frame. StartCamera starts it.
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.TimerOutFun)

    def PrepWidgets(self):
        """Open the camera and disable the buttons that need a running camera."""
        self.PrepCamera()
        self.load_detect_model_bn.setEnabled(False)
        self.load_track_model_bn.setEnabled(False)
        self.load_cfg_bn.setEnabled(False)
        self.detect_bn.setEnabled(False)
        self.track_bn.setEnabled(False)
        self.connet_server_bn.setEnabled(False)

    def PrepCamera(self):
        """Open the default camera (device index 0) with OpenCV."""
        self.camera = cv2.VideoCapture(0)
        if not self.camera.isOpened():
            print('无法打开摄像头')

    def PrepParameters(self):
        """Initialise the state shared between button handlers and timer slots."""
        self.first_frame = True  # tracker still needs tracker.init() on a frame
        self.Image = None        # latest BGR camera frame, set by TimerOutFun
        self.timerdec = None     # detection timer, created by detect()
        self.timerdeck = None    # tracking timer, created by object_tracking()
        self.client_th = None    # TCP client thread, created by socket_open_tcpc()

    def CallBackFunctions(self):
        """Connect every button's clicked signal to its handler."""
        self.cam_bn.clicked.connect(self.StartCamera)
        self.load_detect_model_bn.clicked.connect(self.open_detect_model)
        self.load_cfg_bn.clicked.connect(self.open_cfg)
        self.detect_bn.clicked.connect(self.detect)
        self.load_track_model_bn.clicked.connect(self.open_track_model)
        self.track_bn.clicked.connect(self.object_tracking)
        self.connet_server_bn.clicked.connect(self.socket_open_tcpc)
        self.exit_bn.clicked.connect(self.ExitApp)

    def StartCamera(self):
        """Enable the remaining buttons and start grabbing camera frames."""
        self.cam_bn.setEnabled(False)
        self.load_detect_model_bn.setEnabled(True)
        self.load_cfg_bn.setEnabled(True)
        self.detect_bn.setEnabled(True)
        self.track_bn.setEnabled(True)
        self.load_track_model_bn.setEnabled(True)
        self.connet_server_bn.setEnabled(True)
        self.timer.start(1)  # poll the camera every 1 ms

    def outputWritten(self, text):
        """Append redirected stdout/stderr *text* to textEdit and scroll to it."""
        cursor = self.textEdit.textCursor()
        cursor.movePosition(QtGui.QTextCursor.End)
        cursor.insertText(text)
        self.textEdit.setTextCursor(cursor)
        self.textEdit.ensureCursorVisible()

    def open_detect_model(self):
        """Ask for the YOLOv3 weights file; store its path in the module global."""
        global openfile_name_mdoel
        openfile_name_mdoel, _ = QFileDialog.getOpenFileName(self.load_detect_model_bn, '选择检测模型',
                                                             '/home/ljw/桌面/demo/project_demo/pyqt5/yolov3/')
        print('加载模型文件地址为:' + str(openfile_name_mdoel))

    def open_cfg(self):
        """Ask for the YOLOv3 .cfg file; store its path in the module global."""
        global openfile_name_cfg
        openfile_name_cfg, _ = QFileDialog.getOpenFileName(self.load_cfg_bn, '选择cfg文件',
                                                           '/home/ljw/桌面/demo/project_demo/pyqt5/yolov3/')
        print('加载cfg文件地址为:' + str(openfile_name_cfg))

    def detect(self):
        """Load the YOLOv3 network from the chosen files and start the detection loop.

        Prints a hint and returns if the weights or cfg file has not been chosen
        (or the dialog was cancelled), or if the input resolution is invalid.
        An exception escaping a Qt slot would abort the application.
        """
        cfgfile = globals().get('openfile_name_cfg')
        weightsfile = globals().get('openfile_name_mdoel')
        if not cfgfile or not weightsfile:
            print('请先加载检测模型和cfg文件')
            return
        args = arg_parse()
        self.num_classes = 80  # presumably the 80 COCO classes of the stock weights
        self.confidence = float(args.confidence)
        self.nms_thesh = float(args.nms_thresh)
        self.CUDA = torch.cuda.is_available()
        self.model = Darknet(cfgfile)
        self.model.load_weights(weightsfile)
        self.model.net_info["height"] = args.reso
        self.inp_dim = int(self.model.net_info["height"])
        # The input size must be a multiple of 32 greater than 32.
        if self.inp_dim % 32 != 0 or self.inp_dim <= 32:
            print('输入分辨率必须是大于32的32的倍数: {}'.format(self.inp_dim))
            return
        # Move/switch the model once here instead of on every frame.
        if self.CUDA:
            self.model.cuda()
        self.model.eval()
        self.frames = 0
        self.start = time.time()
        # Create the timer only once, so repeated clicks restart the same
        # timer instead of stacking extra timers/connections.
        if self.timerdec is None:
            self.timerdec = QtCore.QTimer(self)
            self.timerdec.timeout.connect(self.object_detection)
        self.timerdec.start(self.PROCESS_INTERVAL_MS)

    def object_detection(self):
        """Run YOLOv3 on the latest camera frame and show the annotated result."""
        if self.Image is None:
            return  # no camera frame has been read yet
        # Work on a copy so boxes drawn by write() never land in self.Image,
        # which the camera view and the tracker also read. prep_image's code
        # is not shown here; it may hand its input back as orig_im.
        frame = self.Image.copy()
        img, orig_im, _ = prep_image(frame, self.inp_dim)
        if self.CUDA:
            img = img.cuda()  # the model was moved to the GPU in detect()
        with torch.no_grad():  # inference only: no autograd graph needed
            output = self.model(img, self.CUDA)
        output = write_results(output, self.confidence, self.num_classes, nms=True, nms_conf=self.nms_thesh)
        # NOTE(review): write_results presumably returns the int 0 when nothing
        # is detected; indexing that would raise. Verify in util.py.
        if not isinstance(output, int):
            # Box coordinates are in network-input pixels: clamp to the input
            # size, normalise, then scale to the camera frame size.
            output[:, 1:5] = torch.clamp(output[:, 1:5], 0.0, float(self.inp_dim)) / self.inp_dim
            output[:, [1, 3]] *= frame.shape[1]
            output[:, [2, 4]] *= frame.shape[0]
            for detection in output:
                write(detection, orig_im)
        self.frames += 1
        print("FPS of the video is {:5.2f}".format(self.frames / (time.time() - self.start)))
        self.detect_label.setPixmap(self._bgr_to_pixmap(orig_im))
        QApplication.processEvents()

    def open_track_model(self):
        """Ask for the SiamRPN weights file; store its path in the module global."""
        global openfile_name_track_model
        openfile_name_track_model, _ = QFileDialog.getOpenFileName(self.load_track_model_bn, '选择跟踪模型',
                 '/home/ljw/桌面/demo/project_demo/pyqt5/pysot/experiments/siamrpn_alex_dwxcorr/model.pth')
        print('加载跟踪模型地址为:' + str(openfile_name_track_model))

    def object_tracking(self):
        """Build the SiamRPN tracker from the chosen weights and start tracking.

        Stops the detection loop (if running) first. Prints a hint and returns
        if no tracking model has been chosen.
        """
        track_model_path = globals().get('openfile_name_track_model')
        if not track_model_path:
            print('请先加载跟踪模型')
            return
        if self.timerdec is not None:
            self.timerdec.stop()
        cfg.merge_from_file(self.TRACK_CONFIG)
        cfg.CUDA = torch.cuda.is_available()
        device = torch.device('cuda' if cfg.CUDA else 'cpu')
        self.model = ModelBuilder()
        # Load the weights onto the CPU, then move the model to `device`.
        self.model.load_state_dict(torch.load(track_model_path,
                                              map_location=lambda storage, loc: storage.cpu()))
        self.model.eval().to(device)
        self.tracker = build_tracker(self.model)
        # A freshly built tracker must be initialised on a frame again.
        self.first_frame = True
        self.frames = 0
        self.start = time.time()
        if self.timerdeck is None:
            self.timerdeck = QtCore.QTimer(self)
            self.timerdeck.timeout.connect(self.object_tracking_test)
        self.timerdeck.start(self.PROCESS_INTERVAL_MS)

    def object_tracking_test(self):
        """Track the target in the latest camera frame and show the result.

        The first call after object_tracking() initialises the tracker on
        INIT_RECT. Later calls run tracker.track() and draw its output.
        """
        import numpy as np  # explicit, rather than relying on a star import

        if self.Image is None:
            return  # no camera frame has been read yet
        # Draw on a copy: self.Image is only replaced when the camera timer
        # reads a new frame, so annotating it in place could feed an
        # already-annotated frame back into the tracker.
        frame = self.Image.copy()
        if self.first_frame:
            self.tracker.init(frame, list(self.INIT_RECT))
            self.first_frame = False
            return

        outputs = self.tracker.track(frame)
        self.frames += 1
        print("FPS of the video is {:5.2f}".format(self.frames / (time.time() - self.start)))
        if 'polygon' in outputs:
            print("test")
            polygon = np.array(outputs['polygon']).astype(np.int32)
            cv2.polylines(frame, [polygon.reshape((-1, 1, 2))],
                          True, (0, 255, 0), 3)
            mask = ((outputs['mask'] > cfg.TRACK.MASK_THERSHOLD) * 255)
            mask = mask.astype(np.uint8)
            mask = np.stack([mask, mask * 255, mask]).transpose(1, 2, 0)
            frame = cv2.addWeighted(frame, 0.77, mask, 0.23, -1)
        else:
            bbox = list(map(int, outputs['bbox']))
            cv2.rectangle(frame, (bbox[0], bbox[1]),
                          (bbox[0] + bbox[2], bbox[1] + bbox[3]),
                          (0, 255, 0), 3)
        self.track_label.setPixmap(self._bgr_to_pixmap(frame))
        QApplication.processEvents()

    @staticmethod
    def _bgr_to_pixmap(bgr):
        """Convert an OpenCV BGR 3-channel 8-bit image to a QPixmap."""
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Pass the row stride explicitly: without it QImage assumes rows are
        # 4-byte aligned, which skews images whose width * 3 is not a
        # multiple of 4.
        image = QtGui.QImage(rgb.data, rgb.shape[1], rgb.shape[0], rgb.strides[0],
                             QtGui.QImage.Format_RGB888)
        # fromImage copies the pixels, so `rgb` may be freed afterwards.
        return QtGui.QPixmap.fromImage(image)

    def DispImg(self):
        """Show the latest camera frame (self.Image) in cam_label."""
        self.cam_label.setPixmap(self._bgr_to_pixmap(self.Image))

    def socket_open_tcpc(self):
        """Start the TCP client in a background thread.

        A blocking socket on the GUI thread would freeze the window. The
        thread is a daemon so it does not keep the process alive on exit.
        """
        if self.client_th is not None and self.client_th.is_alive():
            print('客户端线程已在运行')
            return
        self.client_th = threading.Thread(target=self.tcp_client_concurrency, daemon=True)
        self.client_th.start()

    def tcp_client_concurrency(self):
        """Connect to the server and exchange messages until it disconnects.

        Runs in the worker thread started by socket_open_tcpc. Connection
        errors are printed instead of escaping the thread.
        """
        print("1")
        send_msg = "abcd"
        try:
            # `with` guarantees the socket is closed however the loop ends.
            with socket.create_connection((self.SERVER_HOST, self.SERVER_PORT)) as client:
                while True:
                    client.sendall(send_msg.encode("utf-8"))
                    msg = client.recv(1024)
                    if not msg:  # empty read: the server closed the connection
                        break
                    # A 1024-byte read can split a multi-byte character.
                    print(msg.decode("utf-8", errors="replace"))
        except OSError as err:
            print('连接服务器失败: {}'.format(err))

    def ExitApp(self):
        """Stop all timers, release the camera and quit the application."""
        self.timer.stop()
        for processing_timer in (self.timerdec, self.timerdeck):
            if processing_timer is not None:
                processing_timer.stop()
        self.camera.release()
        QCoreApplication.quit()

    def TimerOutFun(self):
        """Read one camera frame; on success store it in self.Image and show it."""
        success, img = self.camera.read()
        if success:
            self.Image = img
            self.DispImg()


if __name__ == '__main__':
    # Entry point: build the application window and run the Qt event loop.
    qt_app = QApplication(sys.argv)
    window = CamShow()
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

在与服务器建立通信时,如果不开辟新的线程,socket 的阻塞调用会造成窗口无响应。
具体解决办法可以参考《PyQt5+socket编程界面卡住未响应》。
该程序后续还需要完善:
首先,目标跟踪的初始框目前是人为给定的,需要改为从目标检测的结果中选择一个方框作为初始框;
其次,与服务端建立通信之后需要发送一些数据——这些数据是什么、该怎么提取、怎么发送,下篇文章会陆续解决。

  • 4
    点赞
  • 61
    收藏
    觉得还不错? 一键收藏
  • 24
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值