代码库: QChatBot_GLM
核心代码:
# -*- coding: utf-8 -*-
import os
import sys
import time
import timeit
import sqlite3
import traceback
import psutil
from PyQt5 import QtWidgets, QtCore, QtGui
import qtawesome
import torch
from modelscope import AutoModel, AutoTokenizer
# Names of the model checkpoints listed as supported.
# NOTE(review): nothing in the visible code reads this tuple; it appears to be
# informational only -- confirm before relying on it for validation.
SUPPORTED_MODELS = (
    "chatglm-6b",
    "chatglm2-6b-32k",
    "chatglm2-6b",
    "chatglm3-6b-32k",
    "chatglm3-6b",
    "glm-4-9b-chat",
    "glm-4-9b-chat-1m",
)
# ~~~~~~~~~~~~~~ 参数设置区 ~~~~~~~~~~~~~
# 用启动BAT设置参数
# os.environ['model_path'] = r'E:\AI\models\ZhipuAI\glm-4-9b-chat'
# os.environ['top_p'] = '0.8'
# os.environ['temperature'] = '0.9'
# os.environ['max_length'] = '32768'
# os.environ['log_file'] = 'chatlog.db3'
# 读取参数
# Runtime configuration, read once at import time from environment variables.
# Every variable is required: a missing one raises KeyError, and a value that
# cannot be parsed as a number raises ValueError.
settings = {
    'model_path': os.environ['model_path'],            # directory of the local model
    'top_p': float(os.environ['top_p']),               # sampling parameter passed to stream_chat
    'temperature': float(os.environ['temperature']),   # sampling parameter passed to stream_chat
    'max_length': int(os.environ['max_length']),       # length limit passed to stream_chat
    'log_file': os.environ['log_file'],                # SQLite file used for the chat log
}
# ~~~~~~~~~~~~~~ 用SQLite数据库记录对话内容 ~~~~~~~~~~~~~
# 初始化数据库
def init_chat_database():
    """Create the chat-log table if it does not exist yet.

    Opens the SQLite file named by ``settings['log_file']`` and ensures a
    ``chat`` table with an auto-incrementing ``id`` plus the TEXT columns
    ``user``, ``bot`` and ``time``. The connection is closed on every path,
    including when the statement fails.

    Raises:
        sqlite3.Error: propagated unchanged from the database layer.
    """
    conn = sqlite3.connect(settings['log_file'])
    try:
        # ``with conn`` commits on success and rolls back on error, but it does
        # NOT close the connection -- hence the surrounding try/finally.
        with conn:
            # The SQL text is kept verbatim (continuation lines start at
            # column 0) so the schema text stored by SQLite is unchanged.
            conn.execute('''CREATE TABLE IF NOT EXISTS chat
(id INTEGER PRIMARY KEY AUTOINCREMENT,
user TEXT,
bot TEXT,
time TEXT)''')
    finally:
        conn.close()
# 保存对话内容
def save_chat_to_database(user, bot, time):
    """Append one prompt/response pair to the chat log.

    Args:
        user: text entered by the user.
        bot: text produced by the model.
        time: timestamp, stored as given in the ``time`` column. The
            parameter name shadows the ``time`` module inside this function;
            it is kept unchanged for caller compatibility.

    Raises:
        sqlite3.Error: propagated unchanged from the database layer (for
            example when the ``chat`` table has not been created).
    """
    conn = sqlite3.connect(settings['log_file'])
    try:
        # ``with conn`` commits on success and rolls back on error; the
        # finally clause guarantees the connection is released either way.
        with conn:
            conn.execute(
                "INSERT INTO chat (user, bot, time) VALUES (?, ?, ?)",
                (user, bot, time),
            )
    finally:
        conn.close()
# ~~~~~~~~~~~~~~ 资源监控器 ~~~~~~~~~~~~~
class QCpuGpuMonitor(QtWidgets.QWidget):
    """Compact widget showing generation speed, model name and memory usage.

    Four framed labels are laid out horizontally: the token rate, the last
    path component of ``settings['model_path']``, host memory usage (shown
    under the "CPU" caption) and CUDA device memory usage (shown under the
    "GPU" caption). The two memory labels are refreshed by a timer; the token
    rate is pushed in by the owner through :meth:`update_token_rate`.
    """

    def __init__(self):
        super(QCpuGpuMonitor, self).__init__()
        self.initUI()

    @staticmethod
    def _framed_label(text, tooltip, parent=None):
        """Return a QLabel with the given text, tooltip and a styled-panel frame."""
        label = QtWidgets.QLabel(text, parent)
        label.setToolTip(tooltip)
        label.setFrameStyle(QtWidgets.QFrame.Shape.StyledPanel)
        return label

    def initUI(self):
        """Create the labels and layout, then start the refresh timer."""
        # Generation speed
        self.token_label = self._framed_label('0 token/s', 'token/s', self)
        # Name of the loaded model (last component of the model path)
        model_name = os.path.split(settings['model_path'])[-1]
        self.model_label = self._framed_label(model_name, '当前语言模型', self)
        # Host memory
        self.cpu_label = self._framed_label('CPU: N/A', 'CPU 使用情况')
        # GPU memory
        self.gpu_label = self._framed_label('GPU: N/A', 'GPU 使用情况')

        layout = QtWidgets.QHBoxLayout(self)
        for label in (self.token_label, self.model_label,
                      self.cpu_label, self.gpu_label):
            layout.addWidget(label)
        self.setContentsMargins(0, 0, 0, 0)
        layout.setContentsMargins(0, 0, 0, 0)

        self.timer = QtCore.QTimer(self)
        self.timer.timeout.connect(self.update_status)
        self.timer.start(250)  # refresh every 250 ms (four times per second)

    @staticmethod
    def _gpu_memory_text():
        """Return 'usedMB/totalMB' for the current CUDA device, or 'N/A'."""
        try:
            device = torch.cuda.current_device()
            free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        except Exception:
            # Best-effort display: any failure of the CUDA query shows 'N/A'.
            # ``Exception`` (not a bare except) so that KeyboardInterrupt and
            # SystemExit are not swallowed by the timer slot.
            return 'N/A'
        free_mb = free_bytes / 1024 / 1024
        total_mb = total_bytes / 1024 / 1024
        return '%dMB/%dMB' % (total_mb - free_mb, total_mb)

    @staticmethod
    def _host_memory_text():
        """Return 'usedGB/totalGB' of host memory as reported by psutil, or 'N/A'."""
        try:
            mem = psutil.virtual_memory()
            total_gb = mem.total / 1024 / 1024 / 1024
            used_gb = mem.used / 1024 / 1024 / 1024
        except Exception:
            return 'N/A'
        return '%0.1fGB/%0.1fGB' % (used_gb, total_gb)

    def update_status(self):
        """Timer slot: refresh the host-memory and GPU-memory labels."""
        self.cpu_label.setText('CPU: ' + self._host_memory_text())
        self.gpu_label.setText('GPU: ' + self._gpu_memory_text())

    def update_token_rate(self, token_rate):
        """Show ``token_rate`` (tokens per second) with two decimals."""
        self.token_label.setText('%0.2f token/s' % token_rate)
def test_QCpuGpuMonitor():
    """Manual smoke test: show a stand-alone monitor and run the Qt event loop.

    Blocks until the event loop exits; meant to be run by hand, not by an
    automated test runner.
    """
    qt_app = QtWidgets.QApplication([])
    widget = QCpuGpuMonitor()
    widget.show()
    qt_app.exec_()
# if __name__ == '__main__':
# test_QCpuGpuMonitor()
# ~~~~~~~~~~~~~~ 载入模型 ~~~~~~~~~~~~~
# Load the tokenizer and the model at import time from the directory given by
# settings['model_path']. Both calls pass local_files_only=True and
# trust_remote_code=True; the model additionally uses "auto" for device
# placement and dtype.
tokenizer = AutoTokenizer.from_pretrained(
    settings['model_path'],
    trust_remote_code=True,
    local_files_only=True,
)
model = AutoModel.from_pretrained(
    settings['model_path'],
    torch_dtype="auto",
    device_map="auto",
    local_files_only=True,
    trust_remote_code=True,
)
def process_response(parent, prompt, *args, **kwargs):
    """Stream a model reply for ``prompt`` and report progress through signals.

    For every chunk yielded by ``model.stream_chat`` the function emits the
    time elapsed since the previous chunk (``floatNumber``) and the response
    text so far (``status``). Each call starts with ``history=None``, so no
    earlier turns are passed to the model.

    :param parent: object exposing a boolean ``is_stopped`` attribute. It is
        polled after every chunk; when true, ``terminated`` is emitted and
        generation stops.
    :param prompt: the user's input text.
    :param args: unused.
    :param kwargs: ``signals`` -- object providing the ``terminated``,
        ``floatNumber`` and ``status`` signals (needed as soon as the model
        yields a chunk). Optional ``max_length``, ``top_p`` and
        ``temperature`` override the corresponding ``settings`` values for
        this call only.
    :return: the last response text received from the model; ``''`` when
        nothing was generated or the model/tokenizer is not loaded.
    """
    signals = kwargs.get('signals')
    response = ''
    if model is None or tokenizer is None:
        # Return a string rather than None so a consumer that displays the
        # result can do so without special-casing.
        return response

    # Per-call overrides fall back to the configured values.
    max_length = kwargs.get('max_length', settings['max_length'])
    top_p = kwargs.get('top_p', settings['top_p'])
    temperature = kwargs.get('temperature', settings['temperature'])

    stream = model.stream_chat(
        tokenizer, prompt, history=None, return_past_key_values=False,
        max_length=max_length, top_p=top_p, temperature=temperature,
    )
    time_init = timeit.default_timer()
    for response, _history in stream:
        if parent.is_stopped:
            signals.terminated.emit()
            break
        time_current = timeit.default_timer()
        # Seconds spent on this chunk; the receiver converts it to a rate.
        signals.floatNumber.emit(time_current - time_init)
        time_init = time_current
        signals.status.emit(response)  # generating
    return response
class QChatBot(QtWidgets.QMainWindow):
    """Main chat window: a read-only reply view, an input box and one button.

    The button sends the prompt while idle and requests a stop while a reply
    is being generated. Generation runs in a ``Worker`` on a thread pool; the
    worker polls ``is_stopped`` and reports back through its signals. Each
    finished exchange is appended to the chat-log database.
    """

    def __init__(self):
        super(QChatBot, self).__init__()
        self.setupUi()
        self.initData()

    def setupUi(self):
        """Build widgets, layouts and the status bar, then auto-connect slots."""
        color_skyblue = QtGui.QColor(0, 128, 255)
        color_orange = QtGui.QColor(255, 165, 0)
        icon = qtawesome.icon('fa.weixin', color=color_skyblue)
        self.setWindowIcon(icon)
        self.setWindowTitle('QSimpleChatBot')

        self.textBot = QtWidgets.QPlainTextEdit()
        self.textBot.setReadOnly(True)
        self.textUsr = QtWidgets.QPlainTextEdit()
        self.textUsr.setObjectName('textUsr')
        self.textUsr.setFixedHeight(80)

        self.button = QtWidgets.QPushButton()
        self.button.setFixedSize(80, 80)
        self.button.setIconSize(QtCore.QSize(64, 64))
        # The object name lets connectSlotsByName() bind on_button_clicked.
        self.button.setObjectName('button')
        self.icon_send = qtawesome.icon('fa.send', color=color_skyblue)
        self.icon_stop = qtawesome.icon('fa.stop', color=color_orange)
        self.button.setIcon(self.icon_send)

        mainWidget = QtWidgets.QWidget(self)
        layout = QtWidgets.QVBoxLayout(mainWidget)
        layout.addWidget(self.textBot)
        lay = QtWidgets.QHBoxLayout()
        layout.addLayout(lay)
        lay.addWidget(self.textUsr)
        lay.addWidget(self.button)
        self.setCentralWidget(mainWidget)

        self.statusbar = QtWidgets.QStatusBar(self)
        self.statusbar.setObjectName("statusbar")
        self.setStatusBar(self.statusbar)
        # Resource monitor
        self.monitor = QCpuGpuMonitor()
        self.statusbar.addPermanentWidget(self.monitor)

        QtCore.QMetaObject.connectSlotsByName(self)

    def initData(self):
        """Create the thread pool, reset the state and prepare the chat log."""
        # thread pool
        self.threadpool = QtCore.QThreadPool()
        self.is_stopped = True
        # Prompt of the request currently in flight (what gets logged).
        self.current_prompt = ''
        init_chat_database()

    @QtCore.pyqtSlot()
    def on_button_clicked(self):
        """Send the prompt when idle; request a stop while generating."""
        if not self.is_stopped:
            # Generation in progress: ask the worker to stop. The worker only
            # sees the flag at its next streamed chunk, so keep the button
            # disabled until its `finished` signal arrives; otherwise a quick
            # re-send would flip the flag back and leave two generations
            # running at once.
            self.is_stopped = True
            self.button.setIcon(self.icon_send)
            self.button.setEnabled(False)
            return

        # Idle: send the current input.
        prompt = self.textUsr.toPlainText().strip()
        if not prompt:
            return
        self.is_stopped = False
        self.current_prompt = prompt
        self.button.setIcon(self.icon_stop)

        worker = Worker(process_response, self, prompt)
        worker.signals.status.connect(self.update_progress)
        worker.signals.result.connect(self.update_result)
        worker.signals.error.connect(self.update_error)
        worker.signals.finished.connect(self.finish_generation)
        worker.signals.floatNumber.connect(self.update_token_rate)
        self.threadpool.start(worker)

    def _show_response(self, response):
        """Display ``response`` in the reply view and scroll to the bottom."""
        self.textBot.setPlainText(response)
        bar = self.textBot.verticalScrollBar()
        bar.setValue(bar.maximum())

    def update_progress(self, response, *args, **kwargs):
        """Slot for the worker's ``status`` signal: show the partial reply."""
        self._show_response(response)

    def update_result(self, response):
        """Slot for the worker's ``result`` signal: show and log the reply."""
        if response is None:
            # The result signal carries an arbitrary object; never hand None
            # to the text widget.
            response = ''
        self._show_response(response)
        # Log the prompt that was actually sent, not whatever the input box
        # contains by the time generation finishes.
        save_chat_to_database(
            self.current_prompt, response,
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        # set status
        self.is_stopped = True
        self.button.setIcon(self.icon_send)

    def update_error(self, error):
        """Slot for the worker's ``error`` signal: report it and go idle.

        :param error: tuple ``(exception type, exception value, traceback text)``.
        """
        exctype, value = error[0], error[1]
        self.statusbar.showMessage('%s: %s' % (exctype.__name__, value))
        self.is_stopped = True
        self.button.setIcon(self.icon_send)

    def finish_generation(self):
        """Slot for the worker's ``finished`` signal: re-enable the button."""
        self.button.setEnabled(True)

    def update_token_rate(self, time_per_token):
        """Convert seconds-per-chunk into a rate and show it on the monitor."""
        # A zero (or negative) interval has no meaningful rate; skip it rather
        # than raise ZeroDivisionError inside a slot.
        if time_per_token > 0:
            self.monitor.update_token_rate(1. / time_per_token)
class WorkerSignals(QtCore.QObject):
    """Signal set carried by a ``Worker`` (available as ``worker.signals``)."""

    # --- lifecycle ---
    finished = QtCore.pyqtSignal()         # emitted when the worker's run() ends
    terminated = QtCore.pyqtSignal()       # emitted when the task stops early on request
    error = QtCore.pyqtSignal(tuple)       # (exception type, exception value, traceback text)
    result = QtCore.pyqtSignal(object)     # return value of the wrapped function

    # --- progress reporting ---
    status = QtCore.pyqtSignal(str)                      # free-form status text
    floatNumber = QtCore.pyqtSignal(float)               # generic float payload
    progress = QtCore.pyqtSignal(int)                    # integer progress value
    progressNumberOfTotal = QtCore.pyqtSignal(int, int)  # e.g. 2/100
    progressPercentage = QtCore.pyqtSignal(float)        # 0~1.0
class Worker(QtCore.QRunnable):
    """Runnable that calls ``fn(*args, **kwargs)`` and reports through signals.

    The worker's ``WorkerSignals`` instance is injected into the call as the
    ``signals`` keyword argument, so ``fn`` can emit intermediate progress.
    """

    def __init__(self, fn, *args, **kwargs):
        super(Worker, self).__init__()
        self.signals = WorkerSignals()
        self.fn = fn
        self.args = args
        # Give fn access to the signals so it can report while running.
        kwargs['signals'] = self.signals
        self.kwargs = kwargs

    @QtCore.pyqtSlot()
    def run(self):
        """Execute ``fn`` and emit ``result`` or ``error``, then ``finished``.

        Any exception raised by ``fn`` is printed, then emitted as
        ``(exception type, exception value, traceback text)`` on ``error``.
        ``finished`` is emitted in every case.
        """
        try:
            outcome = self.fn(*self.args, **self.kwargs)
        except BaseException as exc:
            # Catch everything: this is the outermost frame of the pool thread.
            traceback.print_exc()
            self.signals.error.emit((type(exc), exc, traceback.format_exc()))
        else:
            self.signals.result.emit(outcome)
        finally:
            self.signals.finished.emit()
if __name__ == '__main__':
    # Script entry point: create the application and the main window, run the
    # Qt event loop, and use its return value as the process exit status.
    app = QtWidgets.QApplication(sys.argv)
    window = QChatBot()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)