离线中英文文字图像匹配度检测软件(基于CLIP、Transformers等实现)

文字图像匹配度检测软件(基于CLIP、Transformers等实现)

  • 使用CLIP(对比图文预训练方法)提供的图文匹配度检测接口,使用huggingface基于Transformers的机器模型实现离线翻译,因此输入中英文均可检测。前端图形化界面使用PYQT开发,并使用了qdarkstyle进行优化,具体效果如下图所示:
  • 左边一栏是候选文字语句,右边一栏是对应每条文字语句的匹配度 支持中英文
    在这里插入图片描述
  • 点击选择图片,如为我代码中附带的数据集中的图片,那么右边第一列的第一行会附上这张图片的正确描述,如为其他图片,则可以手动输入正确描述,随机抽取中文、英文按钮会下后四行抽取干扰的中文、英文描述,所有候选的5个描述语句均可手动修改。
    下面是一个例子:
    在这里插入图片描述
  • 从结果可以看出,模型对于最贴合图片的那句描述是可以正确识别的,而且效果很好,支持中英文,我自己也做了很多实验测试,代码中也有评估模型准确度的代码testCode.py。

主体代码如下,其余代码以及requirements等打包放在我的资源中,可以下载运行:

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'txtimgui.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again.  Do not edit this file unless you know what you are doing.
import random
import os
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QFileDialog
import torch
from PIL import Image
import translate_main
import clip
import warnings

warnings.filterwarnings("ignore")
global imgNamePath


def getPicName(myLine):
    resName = ''
    if "#enc#0 " in myLine:
        resName = myLine.split("#enc#0 ")[0]
    elif "#zhc#1 " in myLine:
        resName = myLine.split("#zhc#1 ")[0]
    else:
        resName = myLine.split("#zhc#0 ")[0]
    return resName


def getPicSentence(myLine):
    resName = ''
    if "#enc#0 " in myLine:
        resName = myLine.split("#enc#0 ")[1]
    elif "#zhc#1 " in myLine:
        resName = myLine.split("#zhc#1 ")[1]
    else:
        resName = myLine.split("#zhc#0 ")[1]
    return resName


class Ui_MainWindow(object):
    def setupUi(self, MainWindow):
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(800, 600)
        MainWindow.setMinimumSize(QtCore.QSize(80, 30))
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton.setGeometry(QtCore.QRect(30, 90, 91, 31))
        self.pushButton.setObjectName("pushButton")
        self.pushButton.clicked.connect(self.openImage)
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setGeometry(QtCore.QRect(40, 160, 241, 271))
        self.label.setObjectName("label")
        self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
        self.lineEdit.setGeometry(QtCore.QRect(120, 90, 181, 31))
        self.lineEdit.setObjectName("lineEdit")
        self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_2.setGeometry(QtCore.QRect(374, 362, 81, 31))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_2.clicked.connect(self.randomExtract)
        self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_3.setGeometry(QtCore.QRect(514, 362, 81, 31))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_3.clicked.connect(self.matching)
        self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_5.setGeometry(QtCore.QRect(374, 462, 81, 31))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_5.clicked.connect(self.randomExtractEn)
        self.widget = QtWidgets.QWidget(self.centralwidget)
        self.widget.setGeometry(QtCore.QRect(310, 100, 331, 221))
        self.widget.setObjectName("widget")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.widget)
        self.verticalLayout.setContentsMargins(0, 0, 0, 0)
        self.verticalLayout.setObjectName("verticalLayout")
        self.lineEdit_2 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_2.setMinimumSize(QtCore.QSize(100, 30))
        self.lineEdit_2.setObjectName("lineEdit_2")
        self.verticalLayout.addWidget(self.lineEdit_2)
        self.lineEdit_3 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_3.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_3.setObjectName("lineEdit_3")
        self.verticalLayout.addWidget(self.lineEdit_3)
        self.lineEdit_4 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_4.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_4.setObjectName("lineEdit_4")
        self.verticalLayout.addWidget(self.lineEdit_4)
        self.lineEdit_5 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_5.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_5.setObjectName("lineEdit_5")
        self.verticalLayout.addWidget(self.lineEdit_5)
        self.lineEdit_6 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_6.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_6.setObjectName("lineEdit_6")
        self.verticalLayout.addWidget(self.lineEdit_6)
        self.widget1 = QtWidgets.QWidget(self.centralwidget)
        self.widget1.setGeometry(QtCore.QRect(650, 100, 135, 221))
        self.widget1.setObjectName("widget1")
        self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.widget1)
        self.verticalLayout_2.setContentsMargins(0, 0, 0, 0)
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        self.lineEdit_7 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_7.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_7.setObjectName("lineEdit_7")
        self.verticalLayout_2.addWidget(self.lineEdit_7)
        self.lineEdit_8 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_8.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_8.setObjectName("lineEdit_8")
        self.verticalLayout_2.addWidget(self.lineEdit_8)
        self.lineEdit_9 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_9.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_9.setObjectName("lineEdit_9")
        self.verticalLayout_2.addWidget(self.lineEdit_9)
        self.lineEdit_10 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_10.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_10.setObjectName("lineEdit_10")
        self.verticalLayout_2.addWidget(self.lineEdit_10)
        self.lineEdit_11 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_11.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_11.setObjectName("lineEdit_11")
        self.verticalLayout_2.addWidget(self.lineEdit_11)
        MainWindow.setCentralWidget(self.centralwidget)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def openImage(self):
        global imgNamePath
        # 这里为了方便别的地方引用图片路径,将其设置为全局变量
        # 弹出一个文件选择框,第一个返回值imgName记录选中的文件路径+文件名,第二个返回值imgType记录文件的类型
        # QFileDialog就是系统对话框的那个类第一个参数是上下文,第二个参数是弹框的名字,第三个参数是默认打开的路径,第四个参数是需要的格式
        # 设置try-except防止各种不符合要求的操作导致软件退出
        try:
            imgNamePath, imgType = QFileDialog.getOpenFileName(self.centralwidget, "选择图片",
                                                               './dataset',
                                                               "*.jpg;;*.png;;All Files(*)")
            # 通过文件路径获取图片文件,并设置图片长宽为label控件的长、宽
            img = QtGui.QPixmap(imgNamePath).scaled(self.label.width(), self.label.height())
            # 在label控件上显示选择的图片
            self.label.setPixmap(img)
        # 显示所选图片的路径
        except:
            return
        # print(imgNamePath)
        self.lineEdit.setText(imgNamePath)
        try:
            resPath = imgNamePath.split('image/')[1]
        except:
            return
        # 卫星
        for line in open("./dataset/militray_label.txt", encoding='utf-8'):
            if getPicName(line) == resPath:
                print(line)
                self.lineEdit_2.setText(getPicSentence(line))
        # 中文
        for line in open("./dataset/ch_label.txt", encoding='GBK'):
            if getPicName(line) == resPath:
                print(line)
                self.lineEdit_2.setText(getPicSentence(line))
        # 英文
        for line in open("./dataset/enc_label.txt", encoding='GBK'):
            if getPicName(line) == resPath:
                print(line)
                self.lineEdit_2.setText(getPicSentence(line))

    def randomExtract(self):
        # 随机抽取 图片名字和对应正确描述构成映射 读图片的时候把正确的那句话也放到第一个框里
        r1 = random.randint(10, 20)
        r2 = random.randint(21, 30)
        r3 = random.randint(31, 39)
        r4 = random.randint(40, 49)
        f = open("./dataset/militray_label.txt", encoding='utf=8')
        resList = []
        while 1:
            lines = f.readlines(10000)
            if not lines:
                break
            for line in lines:
                resList.append(getPicSentence(line))
        print(resList[r1], resList[r2], resList[r3], resList[r4])
        self.lineEdit_3.setText(resList[r1])
        self.lineEdit_4.setText(resList[r2])
        self.lineEdit_5.setText(resList[r3])
        self.lineEdit_6.setText(resList[r4])
        f.close()

    def randomExtractEn(self):
        # 随机抽取英文 图片名字和对应正确描述构成映射  读图片的时候把正确的那句话也放到第一个框里
        r1 = random.randint(10, 20)
        r2 = random.randint(21, 30)
        r3 = random.randint(31, 39)
        r4 = random.randint(40, 49)
        f = open("./dataset/militray_enc_label.txt", encoding='utf=8')
        resList = []
        while 1:
            lines = f.readlines(10000)
            if not lines:
                break
            for line in lines:
                resList.append(getPicSentence(line))
        print(resList[r1], resList[r2], resList[r3], resList[r4])
        self.lineEdit_3.setText(resList[r1])
        self.lineEdit_4.setText(resList[r2])
        self.lineEdit_5.setText(resList[r3])
        self.lineEdit_6.setText(resList[r4])
        f.close()

    def matching(self):
        t1 = self.lineEdit_2.text()
        t2 = self.lineEdit_3.text()
        t3 = self.lineEdit_4.text()
        t4 = self.lineEdit_5.text()
        t5 = self.lineEdit_6.text()
        s1, s2, s3, s4, s5 = translate_main.trans(t1, t2, t3, t4, t5)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("ViT-B/32", device=device)
        global imgNamePath
        image = preprocess(Image.open(imgNamePath)).unsqueeze(0).to(device)
        text = clip.tokenize([str(s1), str(s2), str(s3), str(s4), str(s5)]).to(device)
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            logits_per_image, logits_per_text = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
            print("文本图像匹配度:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]
            prob = str(probs)[2:-2]
            print(prob)
            t1, t2, t3, t4, t5 = prob.split()
            # 格式化输出 更好看
            # 使用python内置的round()函数
            # a = 1.1314 a = 1.0000 a = 1.1267
            # b = round(a.2)b = round(a.2)b = round(a.2)
            # output b = 1.13 output b = 1.0 output b = 1.13
            t1 = round(float(t1), 4)
            t2 = round(float(t2), 4)
            t3 = round(float(t3), 4)
            t4 = round(float(t4), 4)
            t5 = round(float(t5), 4)
            print(t1, t2, t3, t4, t5)
            self.lineEdit_7.setText(str(t1))
            self.lineEdit_8.setText(str(t2))
            self.lineEdit_9.setText(str(t3))
            self.lineEdit_10.setText(str(t4))
            self.lineEdit_11.setText(str(t5))
            # 下面为记录每次运行的结果
            # 英文测试
            with open('./testResult/enTestResult.txt', 'a+') as writers:
                # 中文测试
                # with open('./testResult/testResult.txt', 'a+') as writers:
                # 打开文件 ‘a+’ ==a+r(可追加可写,文件若不存在就创建)
                if t1 > 0.5:
                    a = imgNamePath
                    b = t1
                    c = 'True'
                    # 如果要按行写入,我们只需要再字符串开头或结尾添加换行符'\n'
                    # writers.write(a + '\n')
                    # 如果想要将多个变量同时写入一行中,可以使用writelines()函数,
                    # 要求将传入的变量写成一个list:
                    # writers.write('\n')
                    # writers.writelines([str(a), ',', str(b), ',', str(c)])
                    writers.write(str(a) + ',' + str(b) + ',' + str(c) + '\n')
                else:
                    res = 'False'
                    writers.write(res + '\n')
            # return t1, t2, t3, t4, t5

    def retranslateUi(self, MainWindow):
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "文字图像匹配度检测"))
        self.pushButton.setText(_translate("MainWindow", "选择图片"))
        self.label.setText(_translate("MainWindow",
                                      "<html><head/><body><p><span style=\" font-size:14pt; font-weight:600;\">图文匹配</span></p></body></html>"))
        self.pushButton_2.setText(_translate("MainWindow", "随机抽取中文"))
        self.pushButton_3.setText(_translate("MainWindow", "开始检测"))
        self.pushButton_5.setText(_translate("MainWindow", "随机抽取英文"))
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zoe_ya

如果你成功申请,可以打赏杯奶茶

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值