ChatGLM-6B源码解析 之 cli_demo.py

python cli_demo.py

 

import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel
import readline

tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False


def build_prompt(history):
    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():
    history = []
    global stop_stream
    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        count = 0
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()

Python脚本,使用Hugging Face的transformers库,

具体使用了ChatGLM-6B模型进行聊天对话。下面逐行解释这段代码:

  1. import os: 导入os模块,用于访问操作系统功能。

  2. import platform: 导入platform模块,用于获取操作系统信息。

  3. import signal: 导入signal模块,用于处理信号。

  4. from transformers import AutoTokenizer, AutoModel: 从transformers库导入AutoTokenizer和AutoModel,用于加载预训练的模型和对应的tokenizer。

  5. import readline: 导入readline模块,用于Python控制台的输入。

  6. tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True): 加载预训练模型的tokenizer,从相对路径"../ChatGLM-Tuning-master/chatglm-6b"。

  7. model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda(): 加载预训练模型,并将其转移到GPU上,同时使用半精度浮点数(half-precision floating point)来提高运算速度。

  8. model = model.eval(): 将模型设置为评估模式,通常在测试或验证阶段使用。

  9. os_name = platform.system(): 获取操作系统名称。

  10. clear_command = 'cls' if os_name == 'Windows' else 'clear': 根据操作系统类型设置清屏命令。

  11. stop_stream = False: 定义一个全局变量stop_stream,用于控制是否停止模型的流式对话。

12-18. def build_prompt(history): 定义一个函数,根据历史对话构建提示文本。

19-22. def signal_handler(signal, frame): 定义一个信号处理函数,当接收到中断信号时,改变全局变量stop_stream的值。

23-47. def main(): 定义主函数,处理用户输入和模型响应的交互,可以响应"stop"和"clear"命令,可以通过Ctrl+C来中断模型的响应。

48-50. if __name__ == "__main__": main(): 如果该文件被直接运行(而不是作为模块导入),则调用main()函数。

接下来我们深入分析一下主要的函数和逻辑

  • def build_prompt(history): 这个函数接收一个历史对话的列表,然后将其格式化为一个提示文本,包括用户和ChatGLM-6B的所有对话内容。

  • def signal_handler(signal, frame): 这是一个信号处理函数,它会在接收到特定信号(如用户按下Ctrl+C)时被调用。函数的作用是将全局变量stop_stream设为True,从而在主循环中用来停止模型的流式对话。

  • def main(): 这是主函数,处理用户输入和模型响应的交互。它首先定义一个空的历史对话列表,然后进入一个无限循环,在循环中等待用户输入,并对用户输入做出响应。这里有几个关键的部分:

    • query = input("\n用户:"): 等待用户输入。
    • if query.strip() == "stop": break: 如果用户输入"stop",则跳出循环,结束程序。
    • if query.strip() == "clear": 如果用户输入"clear",则清空历史对话列表,清屏,并打印欢迎语句。
    • for response, history in model.stream_chat(tokenizer, query, history=history): 使用模型进行流式聊天,模型将在每个步骤生成一个响应,并更新历史对话。这是一个阻塞操作,会等待模型生成响应。
    • if stop_stream: 如果全局变量stop_stream为True(表示收到了中断信号),则停止流式对话,并将stop_stream设回False。
    • if count % 8 == 0: 如果已经生成了8个响应,就清屏并打印历史对话,然后继续等待模型生成更多的响应。
  • if __name__ == "__main__": main() 这是Python的常见模式,只有当脚本被直接运行时,才会执行main()函数。如果脚本被作为模块导入,main()函数就不会被执行。

下面我将详细地进一步解释一下几个重要的语句和逻辑。

  1. tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True): 这一行从本地路径加载了一个预训练的tokenizer。这个tokenizer用于将原始文本输入转化为模型可以理解的形式。参数trust_remote_code=True是说信任远程代码,通常与Hugging Face模型库中的自定义模型有关。

  2. model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda(): 这行代码从本地路径加载了一个预训练的模型,并且使用半精度(half-precision)和GPU(通过.cuda())来进行计算。半精度可以加速模型的计算速度,而牺牲一部分的精度。

  3. for response, history in model.stream_chat(tokenizer, query, history=history): 这里开始了一个流式的聊天对话。模型根据给定的历史对话和用户的最新输入生成响应,生成一个响应后,就立即返回,然后继续生成下一个响应。在每个步骤中,都会返回新的响应和更新后的历史对话。

  4. if count % 8 == 0: 这是一个简单的计数逻辑,每生成8个响应,就清屏并打印所有的历史对话。这样可以保证屏幕上不会有太多的文本。

  5. signal.signal(signal.SIGINT, signal_handler): 这行代码设置了一个信号处理函数,当接收到中断信号(如用户按下Ctrl+C)时,会调用signal_handler函数。这样用户可以通过按Ctrl+C来中断模型的流式对话。

  • 7
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI生成曾小健

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值