基于PyQt5的GLM系列大模型简单对话程序

运行效果
代码库: QChatBot_GLM

核心代码:

# -*- coding: utf-8 -*-
import os
import sys
import time
import timeit
import sqlite3
import traceback

import psutil
from PyQt5 import QtWidgets, QtCore, QtGui
import qtawesome
import torch
from modelscope import AutoModel, AutoTokenizer

# Model names the author lists as supported, grouped by GLM generation.
# NOTE(review): this tuple is not referenced anywhere else in the visible code.
SUPPORTED_MODELS = (
    # ChatGLM (first generation)
    "chatglm-6b",
    # ChatGLM2
    "chatglm2-6b-32k",
    "chatglm2-6b",
    # ChatGLM3
    "chatglm3-6b-32k",
    "chatglm3-6b",
    # GLM-4
    "glm-4-9b-chat",
    "glm-4-9b-chat-1m",
)

# ~~~~~~~~~~~~~~ 参数设置区 ~~~~~~~~~~~~~
# 通过启动BAT脚本设置以下环境变量(下列注释行为示例值)
# os.environ['model_path'] = r'E:\AI\models\ZhipuAI\glm-4-9b-chat'
# os.environ['top_p'] = '0.8'
# os.environ['temperature'] = '0.9'
# os.environ['max_length'] = '32768'
# os.environ['log_file'] = 'chatlog.db3'

# 读取参数
def _read_settings(environ):
    """Collect the runtime parameters from environment variables.

    ``model_path`` is required. ``top_p``, ``temperature``, ``max_length`` and
    ``log_file`` are optional; when unset they fall back to the sample values
    listed in the commented-out parameter section (0.8, 0.9, 32768,
    'chatlog.db3').

    :param environ: mapping of environment variable names to string values
        (normally ``os.environ``).
    :return: dict with keys ``model_path`` (str), ``top_p`` (float),
        ``temperature`` (float), ``max_length`` (int) and ``log_file`` (str).
    :raises KeyError: if ``model_path`` is not set.
    :raises ValueError: if a numeric variable cannot be parsed.
    """
    try:
        model_path = environ['model_path']
    except KeyError:
        # Keep the KeyError type but say what is actually missing.
        raise KeyError(
            "environment variable 'model_path' is not set; "
            "set it (e.g. in the launcher BAT) to a local model directory"
        ) from None
    return dict(
        model_path=model_path,
        top_p=float(environ.get('top_p', '0.8')),
        temperature=float(environ.get('temperature', '0.9')),
        max_length=int(environ.get('max_length', '32768')),
        log_file=environ.get('log_file', 'chatlog.db3'),
    )


settings = _read_settings(os.environ)


# ~~~~~~~~~~~~~~ 用SQLite数据库记录对话内容 ~~~~~~~~~~~~~
# 初始化数据库
def init_chat_database(logfile=None):
    """Create the ``chat`` table in the SQLite log database if it does not exist.

    The table has columns ``id`` (autoincrement primary key), ``user``,
    ``bot`` and ``time``, all text except ``id``. Safe to call repeatedly.

    :param logfile: path of the SQLite database file. Defaults to
        ``settings['log_file']`` when omitted, which is the original behavior.
    """
    if logfile is None:
        logfile = settings['log_file']
    conn = sqlite3.connect(logfile)
    try:
        conn.execute('''CREATE TABLE IF NOT EXISTS chat
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                     user TEXT,
                     bot TEXT,
                     time TEXT)''')
        conn.commit()
    finally:
        # Close even if the DDL/commit fails, so the file handle is not leaked.
        conn.close()


# 保存对话内容
def save_chat_to_database(user, bot, time, logfile=None):
    """Append one prompt/response pair to the ``chat`` table.

    :param user: the user's prompt text.
    :param bot: the model's response text.
    :param time: timestamp string, stored as-is. (This parameter shadows the
        ``time`` module, but only inside this function.)
    :param logfile: path of the SQLite database file. Defaults to
        ``settings['log_file']`` when omitted, which is the original behavior.
    """
    if logfile is None:
        logfile = settings['log_file']
    conn = sqlite3.connect(logfile)
    try:
        # Parameterized query: prompt/response text is arbitrary user/model output.
        conn.execute("INSERT INTO chat (user, bot, time) VALUES (?, ?, ?)", (user, bot, time))
        conn.commit()
    finally:
        # Close even if the insert fails (e.g. database locked).
        conn.close()


# ~~~~~~~~~~~~~~ 资源监控器 ~~~~~~~~~~~~~

class QCpuGpuMonitor(QtWidgets.QWidget):
    """Status-bar widget showing token rate, model name, RAM and GPU memory.

    Despite its "CPU" label, that readout shows system RAM usage from
    ``psutil.virtual_memory()``. The "GPU" readout shows memory in use on the
    current CUDA device. Both are refreshed every 250 ms by a ``QTimer``.
    """

    def __init__(self):
        super(QCpuGpuMonitor, self).__init__()
        self.initUI()

    def initUI(self):
        """Create the four labels, lay them out horizontally and start the timer."""
        # generation speed
        self.token_label = QtWidgets.QLabel('0 token/s', self)
        self.token_label.setToolTip('token/s')
        self.token_label.setFrameStyle(QtWidgets.QFrame.Shape.StyledPanel)
        # current model: last component of model_path. normpath strips a
        # trailing separator, which otherwise produced an empty label.
        model_name = os.path.basename(os.path.normpath(settings['model_path']))
        self.model_label = QtWidgets.QLabel(model_name, self)
        self.model_label.setToolTip('当前语言模型')
        self.model_label.setFrameStyle(QtWidgets.QFrame.Shape.StyledPanel)
        # host memory (labelled "CPU")
        self.cpu_label = QtWidgets.QLabel('CPU: N/A')
        self.cpu_label.setToolTip('CPU 使用情况')
        self.cpu_label.setFrameStyle(QtWidgets.QFrame.Shape.StyledPanel)
        # GPU memory
        self.gpu_label = QtWidgets.QLabel('GPU: N/A')
        self.gpu_label.setToolTip('GPU 使用情况')
        self.gpu_label.setFrameStyle(QtWidgets.QFrame.Shape.StyledPanel)

        layout = QtWidgets.QHBoxLayout(self)
        layout.addWidget(self.token_label)
        layout.addWidget(self.model_label)
        layout.addWidget(self.cpu_label)
        layout.addWidget(self.gpu_label)
        self.setContentsMargins(0, 0, 0, 0)
        layout.setContentsMargins(0, 0, 0, 0)

        self.timer = QtCore.QTimer(self)
        self.timer.timeout.connect(self.update_status)
        self.timer.start(250)  # refresh four times per second

    def update_status(self):
        """Refresh the RAM and GPU-memory labels; show 'N/A' when a query fails."""
        try:
            device = torch.cuda.current_device()
            free_bytes, total_bytes = torch.cuda.mem_get_info(device)
            total_mb = total_bytes / 1024 / 1024
            used_mb = total_mb - free_bytes / 1024 / 1024
            str_gpu = '%dMB/%dMB' % (used_mb, total_mb)
        except Exception:
            # e.g. no CUDA device/driver: degrade to N/A instead of breaking the timer
            str_gpu = 'N/A'

        try:
            mem = psutil.virtual_memory()
            total_gb = mem.total / 1024 / 1024 / 1024
            used_gb = mem.used / 1024 / 1024 / 1024
            str_cpu = '%0.1fGB/%0.1fGB' % (used_gb, total_gb)
        except Exception:
            str_cpu = 'N/A'

        self.cpu_label.setText('CPU: ' + str_cpu)
        self.gpu_label.setText('GPU: ' + str_gpu)

    def update_token_rate(self, token_rate):
        """Display the given generation rate (tokens per second)."""
        self.token_label.setText('%0.2f token/s' % token_rate)


def test_QCpuGpuMonitor():  # manually verified
    """Manual smoke test: show a standalone QCpuGpuMonitor until its window is closed.

    NOTE(review): the ``test_`` prefix means pytest would collect this and
    block on the Qt event loop; it is meant to be run by hand only.
    """
    qt_app = QtWidgets.QApplication([])
    widget = QCpuGpuMonitor()
    widget.show()
    qt_app.exec_()


# if __name__ == '__main__':
#     test_QCpuGpuMonitor()

# ~~~~~~~~~~~~~~ 载入模型 ~~~~~~~~~~~~~
# Runs at import time: importing this module loads the tokenizer and model
# from the local directory in settings['model_path'] (no downloads).
_pretrained_common = dict(trust_remote_code=True, local_files_only=True)

tokenizer = AutoTokenizer.from_pretrained(settings['model_path'], **_pretrained_common)
model = AutoModel.from_pretrained(
    settings['model_path'],
    device_map="auto",
    torch_dtype="auto",
    **_pretrained_common,
)


def process_response(parent, prompt, *args, **kwargs):
    """Stream a reply from the loaded model for a single prompt.

    Intended to run on a worker thread via ``Worker``, which injects the
    ``signals`` keyword argument.

    :param parent: object with a boolean ``is_stopped`` attribute. It is polled
        after each streamed chunk, and generation stops once it is True.
    :param prompt: user input text. It is sent with ``history=None``, i.e.
        without previous turns.
    :param args: unused.
    :param kwargs: must contain ``signals`` (a ``WorkerSignals``). It may
        contain ``max_length``, ``top_p`` and ``temperature``, which override
        the values from ``settings``.
    :return: the response text, which is partial if generation was stopped.
        Returns '' if the model or tokenizer is not loaded. The ``Worker``
        emits this as its ``result`` signal.
    """
    signals = kwargs.get('signals')
    if model is None or tokenizer is None:
        # Return a string, not None, so the result slot can display it safely.
        return ''

    max_length = kwargs.get('max_length', settings['max_length'])
    top_p = kwargs.get('top_p', settings['top_p'])
    temperature = kwargs.get('temperature', settings['temperature'])

    response = ''

    last_tick = timeit.default_timer()
    for response, _history in model.stream_chat(tokenizer, prompt, history=None, return_past_key_values=False,
                                                max_length=max_length, top_p=top_p, temperature=temperature):
        if parent.is_stopped:
            signals.terminated.emit()
            break
        now = timeit.default_timer()
        # Seconds since the previous streamed chunk; the GUI shows 1/dt as
        # "token/s" (presumably one token per chunk - TODO confirm per model).
        signals.floatNumber.emit(now - last_tick)
        last_tick = now
        signals.status.emit(response)  # full text generated so far

    return response


class QChatBot(QtWidgets.QMainWindow):
    """Main chat window: reply view, prompt box, send/stop button and status bar.

    Generation runs on ``self.threadpool`` through ``Worker`` and
    ``process_response``. ``self.is_stopped`` is the shared flag that the
    worker polls to know when to stop.
    """

    def __init__(self):
        super(QChatBot, self).__init__()
        self.setupUi()
        self.initData()

    def setupUi(self):
        """Build the widgets, icons, layout and status-bar monitor."""
        color_skyblue = QtGui.QColor(0, 128, 255)
        color_orange = QtGui.QColor(255, 165, 0)

        icon = qtawesome.icon('fa.weixin', color=color_skyblue)
        self.setWindowIcon(icon)
        self.setWindowTitle('QSimpleChatBot')

        self.textBot = QtWidgets.QPlainTextEdit()
        self.textBot.setReadOnly(True)
        self.textUsr = QtWidgets.QPlainTextEdit()
        self.textUsr.setObjectName('textUsr')
        self.textUsr.setFixedHeight(80)
        self.button = QtWidgets.QPushButton()
        self.button.setFixedSize(80, 80)
        self.button.setIconSize(QtCore.QSize(64, 64))
        # objectName 'button' lets connectSlotsByName wire on_button_clicked
        self.button.setObjectName('button')
        self.icon_send = qtawesome.icon('fa.send', color=color_skyblue)
        self.icon_stop = qtawesome.icon('fa.stop', color=color_orange)
        self.button.setIcon(self.icon_send)

        mainWidget = QtWidgets.QWidget(self)
        layout = QtWidgets.QVBoxLayout(mainWidget)
        layout.addWidget(self.textBot)

        lay = QtWidgets.QHBoxLayout()
        layout.addLayout(lay)
        lay.addWidget(self.textUsr)
        lay.addWidget(self.button)

        self.setCentralWidget(mainWidget)

        self.statusbar = QtWidgets.QStatusBar(self)
        self.statusbar.setObjectName("statusbar")
        self.setStatusBar(self.statusbar)

        # resource monitor
        self.monitor = QCpuGpuMonitor()
        self.statusbar.addPermanentWidget(self.monitor)

        QtCore.QMetaObject.connectSlotsByName(self)

    def initData(self):
        """Initialise runtime state and make sure the chat log table exists."""
        # pool that runs the generation workers
        self.threadpool = QtCore.QThreadPool()
        # True while idle; the worker stops once it sees True
        self.is_stopped = True
        # prompt of the request being generated (logged together with the reply)
        self._current_prompt = ''
        init_chat_database()

    @QtCore.pyqtSlot()
    def on_button_clicked(self):
        """Send the prompt, or request a stop if a generation is running."""
        # A generation is running: request a stop.
        if not self.is_stopped:
            self.is_stopped = True
            self.button.setIcon(self.icon_send)
            # Keep the button disabled until the worker has really finished.
            # Otherwise a new request could set is_stopped=False again and the
            # old worker would keep generating alongside the new one.
            self.button.setEnabled(False)
            return

        # Idle: send the prompt.
        prompt = self.textUsr.toPlainText().strip()
        if not prompt:
            return

        # Remember what was sent; the input box may be edited while generating.
        self._current_prompt = prompt
        self.is_stopped = False
        self.button.setIcon(self.icon_stop)
        self.statusbar.clearMessage()

        worker = Worker(process_response, self, prompt)
        worker.signals.status.connect(self.update_progress)
        worker.signals.result.connect(self.update_result)
        worker.signals.floatNumber.connect(self.update_token_rate)
        worker.signals.error.connect(self._handle_worker_error)
        worker.signals.finished.connect(self._handle_worker_finished)
        self.threadpool.start(worker)

    def update_progress(self, response, *args, **kwargs):
        """Show the partial response and keep the view scrolled to the bottom."""
        self.textBot.setPlainText(response)
        bar = self.textBot.verticalScrollBar()
        bar.setValue(bar.maximum())

    def update_result(self, response):
        """Show the final response, log the exchange and return to the idle state."""
        if response is None:
            # setPlainText(None) would raise inside the slot
            response = ''
        self.textBot.setPlainText(response)
        bar = self.textBot.verticalScrollBar()
        bar.setValue(bar.maximum())

        save_chat_to_database(self._current_prompt, response,
                              time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

        # back to idle
        self.is_stopped = True
        self.button.setIcon(self.icon_send)

    def _handle_worker_error(self, error):
        """Return to idle after the worker raised; show the error in the status bar.

        :param error: ``(exception type, exception value, formatted traceback)``
            as emitted by ``Worker.run``.
        """
        _exctype, value, _tb = error
        self.statusbar.showMessage('生成出错: %s' % value)
        self.is_stopped = True
        self.button.setIcon(self.icon_send)

    def _handle_worker_finished(self):
        """Re-enable the button once the worker has fully finished (success or error)."""
        self.button.setEnabled(True)

    def update_token_rate(self, time_per_token):
        """Convert the seconds-per-chunk value from the worker into a rate display."""
        if time_per_token <= 0:
            # timer resolution can report 0; avoid ZeroDivisionError in the slot
            return
        self.monitor.update_token_rate(1. / time_per_token)

    def closeEvent(self, event):
        """Ask a running generation to stop when the window is closed."""
        self.is_stopped = True
        super(QChatBot, self).closeEvent(event)


class WorkerSignals(QtCore.QObject):
    """
    Signals a Worker (and the function it runs) can emit.

    In the visible code, Worker.run emits result, error and finished, and
    process_response emits status, floatNumber and terminated. progress,
    progressNumberOfTotal and progressPercentage are declared but not emitted.
    """
    terminated = QtCore.pyqtSignal()  # emitted by process_response when it sees parent.is_stopped
    finished = QtCore.pyqtSignal()  # emitted last by Worker.run, after result or error
    error = QtCore.pyqtSignal(tuple)  # (exception type, exception value, formatted traceback)
    result = QtCore.pyqtSignal(object)  # return value of the wrapped function
    progress = QtCore.pyqtSignal(int)
    floatNumber = QtCore.pyqtSignal(float)  # process_response: seconds since the previous streamed chunk
    progressNumberOfTotal = QtCore.pyqtSignal(int, int)  # done / total, e.g. 2/100
    progressPercentage = QtCore.pyqtSignal(float)  # fraction in 0~1.0
    status = QtCore.pyqtSignal(str)  # process_response: response text generated so far


class Worker(QtCore.QRunnable):
    """
    Runs ``fn(*args, **kwargs)`` on a QThreadPool thread and reports via signals.

    A fresh WorkerSignals object is created for each worker and passed to ``fn``
    as the ``signals`` keyword argument, replacing any caller-supplied value.
    This lets ``fn`` emit progress while it runs.
    """
    def __init__(self, fn, *args, **kwargs):
        super(Worker, self).__init__()
        self.signals = WorkerSignals()
        # hand the signals object to fn so it can report progress
        kwargs['signals'] = self.signals
        self.fn, self.args, self.kwargs = fn, args, kwargs

    @QtCore.pyqtSlot()
    def run(self):
        """
        Call the wrapped function, then emit result or error, and finally finished.
        """
        try:
            try:
                outcome = self.fn(*self.args, **self.kwargs)
            except BaseException:
                # catch everything (same as a bare except) and report it
                traceback.print_exc()
                exc_type, exc_value = sys.exc_info()[:2]
                self.signals.error.emit((exc_type, exc_value, traceback.format_exc()))
            else:
                self.signals.result.emit(outcome)
        finally:
            self.signals.finished.emit()

if __name__ == '__main__':
    # Start the Qt application, show the chat window and exit with Qt's return code.
    qt_app = QtWidgets.QApplication(sys.argv)
    window = QChatBot()
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值