python cli_demo.py
import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel
import readline
tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
def build_prompt(history):
prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM-6B:{response}"
return prompt
def signal_handler(signal, frame):
global stop_stream
stop_stream = True
def main():
history = []
global stop_stream
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
if query.strip() == "clear":
history = []
os.system(clear_command)
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
continue
count = 0
for response, history in model.stream_chat(tokenizer, query, history=history):
if stop_stream:
stop_stream = False
break
else:
count += 1
if count % 8 == 0:
os.system(clear_command)
print(build_prompt(history), flush=True)
signal.signal(signal.SIGINT, signal_handler)
os.system(clear_command)
print(build_prompt(history), flush=True)
if __name__ == "__main__":
main()
Python脚本,使用Hugging Face的transformers库,
具体使用了ChatGLM-6B模型进行聊天对话。下面逐行解释这段代码:
-
import os
: 导入os模块,用于访问操作系统功能。 -
import platform
: 导入platform模块,用于获取操作系统信息。 -
import signal
: 导入signal模块,用于处理信号。 -
from transformers import AutoTokenizer, AutoModel
: 从transformers库导入AutoTokenizer和AutoModel,用于加载预训练的模型和对应的tokenizer。 -
import readline
: 导入readline模块,用于Python控制台的输入。 -
tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True)
: 加载预训练模型的tokenizer,从相对路径"../ChatGLM-Tuning-master/chatglm-6b"。 -
model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda()
: 加载预训练模型,并将其转移到GPU上,同时使用半精度浮点数(half-precision floating point)来提高运算速度。 -
model = model.eval()
: 将模型设置为评估模式,通常在测试或验证阶段使用。 -
os_name = platform.system()
: 获取操作系统名称。 -
clear_command = 'cls' if os_name == 'Windows' else 'clear'
: 根据操作系统类型设置清屏命令。 -
stop_stream = False
: 定义一个全局变量stop_stream,用于控制是否停止模型的流式对话。
12-18. def build_prompt(history)
: 定义一个函数,根据历史对话构建提示文本。
19-22. def signal_handler(signal, frame)
: 定义一个信号处理函数,当接收到中断信号时,改变全局变量stop_stream的值。
23-47. def main()
: 定义主函数,处理用户输入和模型响应的交互,可以响应"stop"和"clear"命令,可以通过Ctrl+C来中断模型的响应。
48-50. if __name__ == "__main__": main()
: 如果该文件被直接运行(而不是作为模块导入),则调用main()函数。
接下来我们深入分析一下主要的函数和逻辑。
-
def build_prompt(history):
这个函数接收一个历史对话的列表,然后将其格式化为一个提示文本,包括用户和ChatGLM-6B的所有对话内容。 -
def signal_handler(signal, frame):
这是一个信号处理函数,它会在接收到特定信号(如用户按下Ctrl+C)时被调用。函数的作用是将全局变量stop_stream
设为True,从而在主循环中用来停止模型的流式对话。 -
def main():
这是主函数,处理用户输入和模型响应的交互。它首先定义一个空的历史对话列表,然后进入一个无限循环,在循环中等待用户输入,并对用户输入做出响应。这里有几个关键的部分:query = input("\n用户:")
: 等待用户输入。if query.strip() == "stop": break
: 如果用户输入"stop",则跳出循环,结束程序。if query.strip() == "clear":
如果用户输入"clear",则清空历史对话列表,清屏,并打印欢迎语句。for response, history in model.stream_chat(tokenizer, query, history=history):
使用模型进行流式聊天,模型将在每个步骤生成一个响应,并更新历史对话。这是一个阻塞操作,会等待模型生成响应。if stop_stream:
如果全局变量stop_stream
为True(表示收到了中断信号),则停止流式对话,并将stop_stream
设回False。if count % 8 == 0:
如果已经生成了8个响应,就清屏并打印历史对话,然后继续等待模型生成更多的响应。
-
if __name__ == "__main__": main()
这是Python的常见模式,只有当脚本被直接运行时,才会执行main()
函数。如果脚本被作为模块导入,main()
函数就不会被执行。
下面我将详细地进一步解释一下几个重要的语句和逻辑。
-
tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True)
: 这一行从本地路径加载了一个预训练的tokenizer。这个tokenizer用于将原始文本输入转化为模型可以理解的形式。参数trust_remote_code=True
是说信任远程代码,通常与Hugging Face模型库中的自定义模型有关。 -
model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda()
: 这行代码从本地路径加载了一个预训练的模型,并且使用半精度(half-precision)和GPU(通过.cuda())来进行计算。半精度可以加速模型的计算速度,而牺牲一部分的精度。 -
for response, history in model.stream_chat(tokenizer, query, history=history):
这里开始了一个流式的聊天对话。模型根据给定的历史对话和用户的最新输入生成响应,生成一个响应后,就立即返回,然后继续生成下一个响应。在每个步骤中,都会返回新的响应和更新后的历史对话。 -
if count % 8 == 0:
这是一个简单的计数逻辑,每生成8个响应,就清屏并打印所有的历史对话。这样可以保证屏幕上不会有太多的文本。 -
signal.signal(signal.SIGINT, signal_handler)
: 这行代码设置了一个信号处理函数,当接收到中断信号(如用户按下Ctrl+C)时,会调用signal_handler
函数。这样用户可以通过按Ctrl+C来中断模型的流式对话。