【PyTorch + PyQt + 网络爬虫】—> 基于resnet34的中草药识别

原作者: Pytorch实现中药材(中草药)分类识别(含训练代码和数据集)
本人在原作者代码基础上进行改进,这里先给出原作者链接,以示尊重原作。

前言:基于原作者代码,经训练了mobilenet_v2、resnet18和resnet34三种模型效果后,本人选择识别准确率最高的resnet34进行识别。然通过网络爬虫爬取各类中草药功能信息。使用QT界面展示每识别一张图片输出对应的功能信息。

效果

在这里插入图片描述

resnet34

  • 直接调用pytorch预训练好的resnet34模型
from torchvision.models.resnet import model_urls
	backbone = models.resnet34(pretrained=False)

tensorboard可视化

  • 终端运行以下一行代码,点开输出的网址。
tensorboard --logdir=路径/resnet34_1.0_CrossEntropyLoss_20230524125658/log

mobilenet_v2:
在这里插入图片描述

resnet18:
在这里插入图片描述
resnet34:
在这里插入图片描述

  • 综上,resnet34的 train 准确率为99.28%, test 准确率为98.42。

网络爬虫

  • 通过搜索每一种中草药,爬取对应的功能信息。请求需间隔一定时间,不然请求会失败,可能出现阿贾克斯请求;同时在爬取时不要登录账号。爬取163类中草药需进行两次,爬到一半左右程序会报错中断,这时需将代码47行的for循环的索引进行对应的调整再继续爬取剩下的。
from bs4 import BeautifulSoup
import urllib.request
import re
import time

class ZYC():
    def __init__(self):
        # 伪装成浏览器访问,适用于拒绝爬虫的网站
        self.headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.50'}

    def gethtml(self, name):
        # 修改链接拼接方式
        full_path = "/zhongyaocai"
        params = {
            'keyword': name,
            'jc_me_name': 1,
            'me_pinyin': 1,
            'me_englishname': 1,
            'me_latinname': 1,
            'use': '',
            'hauxuechenhfen': '',
            'zhuyi': '',
            'yctype': '全部',
        }
        encoded_params = urllib.parse.urlencode(params)

        # 拼接完整的 URL,请参考修改的参数
        url = f'https://db.yaozh.com{full_path}?{encoded_params}'

        # 在创建 Request 对象时指定 headers 参数,而不是使用 encoding 参数
        self.headers['Content-Type'] = 'application/json;charset=utf-8'
        req = urllib.request.Request(url, headers=self.headers)

        req_timeout = 5
        response = urllib.request.urlopen(req, None, req_timeout)
        html = response.read().decode('utf-8')
        return html

    # 获取自己想要的内容
    def getinformation(self):
        # 读取中药名称
        words = []  # 存储txt中药材名称
        with open('class_names.txt', 'r') as class_name:
            for line in class_name:
                word = line.strip()  # strip()方法用于去除每行末尾的字符(默认是空格和换行符)
                words.append(word)  # 将当前行读到的单词添加到列表中
        for name in range(len(words)):
            html = self.gethtml(words[name])  # 传入药材名称
            reg = re.compile(r"</p>") # 利用正则表达式去除
            html = reg.sub('', html)
            reg = re.compile(r"<p>")
            html = reg.sub('', html)
            soup = BeautifulSoup(html, "html.parser") # 转换成 BeautifulSoup 对象
            Trlist = soup.find_all('tr')

            # 获取
            try:
                # # 获取标题
                # for item in Trlist[0]:  # 包含了第一页中所有的列名
                #     if item not in ['\n', '\t', ' ']:  # 检测换行符或空白符
                #         item = item.get_text(strip=True)  # 转换为纯文本格式
                #         with open("Chinese_herbal.txt", "a") as file:
                #             file.write(item + '|')  # 分隔符

                # 获取内容
                file = open("Chinese_herbal.txt", "a")
                file.write('\n')
                for te in Trlist[1]:
                    for item in te:
                        if item not in ['\n', ' ', '\s']:
                            item = item.get_text(strip=True)
                            reg = re.compile(r'\s+')  # 正则表达式匹配任意连续的空白符
                            item = reg.sub('', item)  # 替换为空
                            file.write(item + '|')
                file.close()
                print("--正在采集%s信息--" % words[name])
                time.sleep(5)  # 延迟程序执行的时间;等待5秒,然后再进行下一个网页的爬取。
            except IndexError:
                print(f"No data found for {words[name]}")


if __name__ == '__main__':
    ZYC().getinformation()

demo

代码逻辑:

导入库

import sys
sys.path.append("libs")
import argparse
from basetrainer.utils import setup_config
from pybaseutils import file_utils, image_utils
from classifier import inference
from PyQt5.QtWidgets import QApplication, QMainWindow
from ui import Ui_MainWindow
from PyQt5 import QtGui

再看主函数

if __name__ == "__main__":
    # 路径——>解析器——>预测分类——>实现
    parser = get_parser()
    # cfg配置对象;setup_config解析命令行参数(parser.parse_args():命令行参数解析器, cfg_updata使用默认值)。
    cfg = setup_config.parser_config(parser.parse_args(), cfg_updata=False)
    t = Predictor(cfg)
    # (image_dir图像目录,shuffle=False表示不要打乱目录中文件的顺序(默认为True,即打乱))
    t.image_dir_predict(cfg.image_dir, shuffle=False)

该步骤进行配置文件和训练好模型文件的调用

def get_parser():
    # 配置文件
    # config_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/config.yaml"
    # # 模型文件
    # model_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/model/best_model_116_98.4700.pth"
    # 配置文件
    config_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/config.yaml"
    # 模型文件
    model_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/model/best_model_114_98.5700.pth"

    # 待测试图片目录
    image_dir = "data/test_images"
    # 创建用于解析命令行参数的ArgumentParser对象;用于读取命令行输入,并将输入转换为其他数据类型以供程序使用。
    parser = argparse.ArgumentParser(description="Inference Argument")
    # 使用argparse库来解析命令行参数的函数。
    # ("-c", "--config_file":参数名称,分别代表短选项何长选项;
    # help=:该选项的描述信息;default=:不提供选项则使用默认值;type=:类型)
    parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
    parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
    parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
    parser.add_argument("--image_dir", help="image file or directory", default=image_dir, type=str)
    return parser

再而进入识别的类中

class Predictor(inference.Inference):
    def __init__(self, cfg):
        super(Predictor, self).__init__(cfg)

    def predict(self, image):
        """预测类别"""
        # pred_index类别标签,pred_score置信度得分
        pred_index, pred_score = self.inference(image)
        # 将模型输出的类别pred_index使用self.label2class_name函数进行转换,返回对应的类别名称
        pred_index = self.label2class_name(pred_index)
        return pred_index, pred_score

然后进行图片的识别,以得到图片识别的类别和准确率

    def image_dir_predict(self, image_dir, vis=True, use_rgb=True, shuffle=False):
        """
        :param image_dir: list,*.txt ,图像路径或目录
        :param vis: 是否可视化
        :param use_rgb: 是否转换为RGB格式
        :param shuffle:  是否打乱顺序
        """
        pred_index_base = [] # 存储名称,穿进嵌套的类中
        formatted_string_base = [] # 存储准确率,穿进嵌套的类中
        image_list = file_utils.get_files_lists(image_dir, shuffle=shuffle) # 读取image_dir
        cnt = len(image_list) # 统计长度
        for path in image_list:
            image = image_utils.read_image_ch(path, use_rgb=use_rgb)  # 读取图像文件,对该文件进行解码并将其转换为numpy数组。
            pred_index, pred_score = self.predict(image)  # 预测类别
            value = float(pred_score[0])  # 准确率位数限制
            formatted_string = "%.5f" % value
            pred_index_base.append(pred_index[0])
            formatted_string_base.append(formatted_string)

这里嵌套QT界面的类,如果两个类合并则会出现冲突,所以在上一步将所有识别的图片名称、准确率存储进列表,然后在QT类中使用计数器原理进行调用,当按钮触发后才进入下一张图片的展示。

		# 嵌套QT界面
        if vis:  # 图像显示
            class MainWindow(QMainWindow, Ui_MainWindow):
                def __init__(self):
                    super().__init__()
                    # 设置UI
                    self.setupUi(self)
                    self.tmp = 0 # 计数器
                    self.pushButton.clicked.connect(self.change_image) # 按钮

                def change_image(self):
                    if self.tmp == cnt: # 达到最后一张照片则
                        sys.exit() # 关闭界面
                    else:
                        self.textBrowser.clear() # 清空上一次的信息
                        self.textBrowser_2.clear()
                        self.textBrowser_3.clear()
                        self.textBrowser_4.clear()
                        self.textBrowser_5.clear()
                        self.textBrowser_6.clear()
                        path = image_list[self.tmp]
                        # 将图片设置为QLabel的内容
                        self.label_2.setPixmap(QtGui.QPixmap(path)) # 图片
                        self.textBrowser.append(pred_index_base[self.tmp]) # 名称
                        self.textBrowser_2.append(str(formatted_string_base[self.tmp])) # 准确率
                        self.BinarySearch(pred_index_base[self.tmp]) # 详情
                        self.tmps() # 计数器

                def tmps(self):
                    self.tmp += 1
                # 信息输出
                def BinarySearch(self, path):
                    # 读取中药名称
                    words = []  # 存储txt中药材名称
                    with open('class_names.txt', 'r') as f:  # 读取中药名称
                        for line in f:
                            word = line.strip()  # strip()方法用于去除每行末尾的字符(默认是空格和换行符)
                            words.append(word)  # 将当前行读到的单词添加到列表中
                    # 查找
                    with open('Chinese_herbal.txt', 'r') as t:  # 读取中药信息
                        temp = t.read()
                        line_list = temp.splitlines()
                        if len(path) <= 2:  # 二分查找思想
                            for i in range(91):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])
                        else:
                            for i in range(91, 163):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])

            app = QApplication(sys.argv)
            main_win = MainWindow()
            main_win.show()
            sys.exit(app.exec_())

完整代码

import sys
sys.path.append("libs")
import argparse
from basetrainer.utils import setup_config
from pybaseutils import file_utils, image_utils
from classifier import inference
from PyQt5.QtWidgets import QApplication, QMainWindow
from ui import Ui_MainWindow
from PyQt5 import QtGui

class Predictor(inference.Inference):
    def __init__(self, cfg):
        super(Predictor, self).__init__(cfg)

    def predict(self, image):
        """预测类别"""
        # pred_index类别标签,pred_score置信度得分
        pred_index, pred_score = self.inference(image)
        # 将模型输出的类别pred_index使用self.label2class_name函数进行转换,返回对应的类别名称
        pred_index = self.label2class_name(pred_index)
        return pred_index, pred_score

    def image_dir_predict(self, image_dir, vis=True, use_rgb=True, shuffle=False):
        """
        :param image_dir: list,*.txt ,图像路径或目录
        :param vis: 是否可视化
        :param use_rgb: 是否转换为RGB格式
        :param shuffle:  是否打乱顺序
        """
        pred_index_base = [] # 存储名称,穿进嵌套的类中
        formatted_string_base = [] # 存储准确率,穿进嵌套的类中
        image_list = file_utils.get_files_lists(image_dir, shuffle=shuffle) # 读取image_dir
        cnt = len(image_list) # 统计长度
        for path in image_list:
            image = image_utils.read_image_ch(path, use_rgb=use_rgb)  # 读取图像文件,对该文件进行解码并将其转换为numpy数组。
            pred_index, pred_score = self.predict(image)  # 预测类别
            value = float(pred_score[0])  # 准确率位数限制
            formatted_string = "%.5f" % value
            pred_index_base.append(pred_index[0])
            formatted_string_base.append(formatted_string)

        # 嵌套QT界面
        if vis:  # 图像显示
            class MainWindow(QMainWindow, Ui_MainWindow):
                def __init__(self):
                    super().__init__()
                    # 设置UI
                    self.setupUi(self)
                    self.tmp = 0 # 计数器
                    self.pushButton.clicked.connect(self.change_image) # 按钮

                def change_image(self):
                    if self.tmp == cnt: # 达到最后一张照片则
                        sys.exit() # 关闭界面
                    else:
                        self.textBrowser.clear() # 清空上一次的信息
                        self.textBrowser_2.clear()
                        self.textBrowser_3.clear()
                        self.textBrowser_4.clear()
                        self.textBrowser_5.clear()
                        self.textBrowser_6.clear()
                        path = image_list[self.tmp]
                        # 将图片设置为QLabel的内容
                        self.label_2.setPixmap(QtGui.QPixmap(path)) # 图片
                        self.textBrowser.append(pred_index_base[self.tmp]) # 名称
                        self.textBrowser_2.append(str(formatted_string_base[self.tmp])) # 准确率
                        self.BinarySearch(pred_index_base[self.tmp]) # 详情
                        self.tmps() # 计数器

                def tmps(self):
                    self.tmp += 1
                # 信息输出
                def BinarySearch(self, path):
                    # 读取中药名称
                    words = []  # 存储txt中药材名称
                    with open('class_names.txt', 'r') as f:  # 读取中药名称
                        for line in f:
                            word = line.strip()  # strip()方法用于去除每行末尾的字符(默认是空格和换行符)
                            words.append(word)  # 将当前行读到的单词添加到列表中
                    # 查找
                    with open('Chinese_herbal.txt', 'r') as t:  # 读取中药信息
                        temp = t.read()
                        line_list = temp.splitlines()
                        if len(path) <= 2:  # 二分查找思想
                            for i in range(91):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])
                        else:
                            for i in range(91, 163):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])

            app = QApplication(sys.argv)
            main_win = MainWindow()
            main_win.show()
            sys.exit(app.exec_())


def get_parser():
    # 配置文件
    # config_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/config.yaml"
    # # 模型文件
    # model_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/model/best_model_116_98.4700.pth"
    # 配置文件
    config_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/config.yaml"
    # 模型文件
    model_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/model/best_model_114_98.5700.pth"

    # 待测试图片目录
    image_dir = "data/test_images"
    # 创建用于解析命令行参数的ArgumentParser对象;用于读取命令行输入,并将输入转换为其他数据类型以供程序使用。
    parser = argparse.ArgumentParser(description="Inference Argument")
    # 使用argparse库来解析命令行参数的函数。
    # ("-c", "--config_file":参数名称,分别代表短选项何长选项;
    # help=:该选项的描述信息;default=:不提供选项则使用默认值;type=:类型)
    parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
    parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
    parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
    parser.add_argument("--image_dir", help="image file or directory", default=image_dir, type=str)
    return parser


if __name__ == "__main__":
    # 路径——>解析器——>预测分类——>实现
    parser = get_parser()
    # cfg配置对象;setup_config解析命令行参数(parser.parse_args():命令行参数解析器, cfg_updata使用默认值)。
    cfg = setup_config.parser_config(parser.parse_args(), cfg_updata=False)
    t = Predictor(cfg)
    # (image_dir图像目录,shuffle=False表示不要打乱目录中文件的顺序(默认为True,即打乱))
    t.image_dir_predict(cfg.image_dir, shuffle=False)

项目不足

  • 内存空间大
    当识别的图片过多时,存储名称、准确率的列表需求空间较大。
  • 测试对象静态
    后续可将图片换为视频或者实时捕获。
### 回答1: 表情识别代码是在PyTorch框架下实现的一种图像处理技术。通过使用UI(用户界面),我们可以使这个代码更加友好和易于使用。 PyTorch是一种流行的深度学习框架,它提供了丰富的工具和函数来构建和训练神经网络模型。在表情识别中,我们可以使用PyTorch来构建一个卷积神经网络(CNN)模型,以识别输入图片中的表情。 UI是指用户界面,它是让用户与计算机程序进行交互的一种方式。通过使用UI,我们可以将表情识别代码制作成一个交互式的应用程序,以便用户可以直观地使用这个功能。 在UI中,我们可以添加一个文件选择按钮,让用户选择要识别的表情图片。然后,我们可以添加一个“识别”按钮,当用户点击它时,代码会调用PyTorch模型来对选择的图片进行表情识别识别结果可以通过界面上的文本框或图像显示出来。 此外,我们还可以添加一些其他的功能,如显示当前选择的图片、预处理图片、调整模型参数等。这些功能可以使用户更方便地使用和了解表情识别代码。 总之,通过将表情识别代码与PyTorch和UI结合起来,我们可以实现一个功能强大、易于使用的表情识别应用程序。用户可以通过界面直观地选择和识别表情,这大大提高了代码的可用性和用户体验。 ### 回答2: 表情识别代码是指使用pytorch框架开发的一种图像处理代码,用于识别人脸表情。这种代码通常使用了深度学习的方法,通过对输入图像进行分类来识别出人脸的表情。 在pytorch框架中,可以使用torchvision库提供的一些预训练的模型来进行表情识别。常见的预训练模型有VGGNet、ResNet等,它们能够提取图像的特征信息。我们可以利用这些预训练模型,将图像输入网络中,经过前向传播得到输出结果。 代码中首先需要导入相关的库和模块,例如torch、torchvision以及相关的数据集等。然后,可以定义一个网络模型,可以选择使用预训练模型或自己设计模型。接着,需要设置模型的超参数,如学习率、优化器等。然后定义训练和测试的过程,包括数据加载、前向传播、计算损失、反向传播以及更新参数等。最后,可以对模型进行训练和测试,分别输出模型在训练集和测试集上的准确率。 在实际运行时,可以使用pytorch的图形用户界面(UI)库,如PyQt或Tkinter等,来设计一个用户友好的界面。通过该界面,用户可以选择图片或视频作为输入,然后点击按钮进行表情识别,最后显示结果在界面上。这样,用户就可以直观地看到图像的表情识别结果。 总之,表情识别代码pytorch ui是指使用pytorch框架开发的一个具有图形界面的表情识别代码,能够通过图像输入进行表情的分类识别,并将结果可视化展示给用户。 ### 回答3: 表情识别是一种利用计算机视觉技术和机器学习算法识别人脸表情的应用。PyTorch是一个基于Python的机器学习框架,提供了丰富的工具和库来构建和训练深度学习模型。 对于表情识别代码的编写,我们可以使用PyTorch来实现。首先,需要收集带有不同表情的人脸图像数据集。这些图像应包含各种表情,如开心、悲伤、惊讶等。然后,可以使用PyTorch提供的图像处理库来对这些图像进行预处理,例如裁剪、缩放和归一化。 接下来,我们可以使用PyTorch来构建卷积神经网络(CNN)模型。CNN是一种用于图像识别的深度学习模型,在图像分类任务中表现出色。我们可以使用PyTorch提供的函数和类来构建网络结构,例如卷积层、池化层和全连接层。同时,我们还可以使用PyTorch的自动求导功能来计算和优化模型参数。 在模型构建完成后,我们需要将数据集分为训练集和测试集。训练集用于训练模型参数,测试集用于评估模型性能。可以使用PyTorch提供的数据加载器和数据拆分函数来实现这一过程。 然后,我们可以使用PyTorch的优化器和损失函数来进行模型训练和优化。通过迭代训练和调整模型参数,我们可以使模型逐渐提高表情识别的准确度。 最后,我们可以使用PyTorch搭建一个简单的用户界面(UI)来进行表情识别。可以使用PyTorch的图像处理库来处理用户提供的图像输入,并应用训练好的模型来识别表情。通过将识别结果显示在UI上,用户即可得到相应的表情识别结果。 综上所述,通过PyTorch构建表情识别代码和用户界面,我们可以实现对人脸表情进行自动识别和分类的功能。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

应用数学

只要想学即可

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值