项目实训11-命令行测试

控制台运行测试

装载适配器与模型控制台运行

首先在控制台设计了一个简单的对话系统,用来测试我训练好的模型的回答能力。是一个简单的命令行聊天应用,用于与经过微调的语言模型进行交互。

首先我需要准备并解析模型、数据、微调和生成的参数。将我训练好的GLM的适配器插入我们的模型,并在控制台运行一个持久化接口来不断测试回答我们的问题。这里用到了部分手动实现的utils包中的代码

加载模型

# 准备并解析模型、数据、微调和生成的参数。
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
# 加载预训练模型和分词器。
model, tokenizer = load_pretrained(model_args, finetuning_args)
# 初始化提示模板。
prompt_template = Template(data_args.prompt_template)
# 设置源前缀。
source_prefix = data_args.source_prefix if data_args.source_prefix else ""

用户回复

要实现用户回复的功能则完成如下的内容。

1. 输入处理和准备
python复制代码input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
  • 输入令牌化:
    • prompt_template.get_prompt(query, history, source_prefix):使用提示模板将用户的查询、历史对话和源前缀组合成模型的输入格式。
    • tokenizer:将上述组合的文本转换为模型可以理解的输入ID(张量形式)。
    • input_ids.to(model.device):将输入ID移至模型的计算设备(如GPU)。
2. 初始化文本流处理器
python
复制代码
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
  • 文本流处理器:
    • TextIteratorStreamer:初始化文本流处理器,用于逐步接收和处理模型生成的文本。
    • timeout=60.0:设置超时时间为60秒。
    • skip_prompt=True:跳过提示部分,只关注生成的回复。
    • skip_special_tokens=True:跳过特殊令牌。
3. 设置生成参数
python复制代码gen_kwargs = generating_args.to_dict()
gen_kwargs.update({
    "do_sample": True,
    "num_beams": 1,
    "input_ids": input_ids,
    "temperature": 0.5,
    "top_p": 0.7,
    "top_k": 50,
    "repetition_penalty": 1.2,
    "length_penalty": 1.5,
    "logits_processor": get_logits_processor(),
    "streamer": streamer,
})
gen_kwargs.pop("max_new_tokens", None)
gen_kwargs["max_length"] = 8000
gen_kwargs.pop("max_length", None)
gen_kwargs["max_new_tokens"] = 7200
  • 生成参数:
    • do_sample=True:启用采样生成。
    • num_beams=1:设置束搜索数为1。
    • temperature=0.5:设置温度参数,控制生成文本的多样性。
    • top_p=0.7:设置核采样的累积概率阈值。
    • top_k=50:设置前k个候选项的采样。
    • repetition_penalty=1.2:设置重复惩罚参数。
    • length_penalty=1.5:设置长度惩罚参数。
    • get_logits_processor():获取logits处理器。
    • streamer:将文本流处理器传递给生成参数。
    • max_lengthmax_new_tokens:设置生成文本的最大长度。
4. 生成回复
python复制代码thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
  • 多线程生成:
    • 创建一个线程,目标函数为 model.generate,并传入生成参数 gen_kwargs
    • 启动线程以生成回复。
5. 打印生成的回复
python复制代码print("Assistant: ", end="", flush=True)

response = ""
for new_text in streamer:
    print(new_text, end="", flush=True)
    response += new_text
print()
  • 实时打印:
    • 打印初始提示 Assistant:
    • 使用 for new_text in streamer 循环逐步接收并打印生成的文本。
    • 将每个新生成的文本片段追加到 response 中。
    • 最后打印完整回复。
6. 更新历史对话
python复制代码history = history + [(query, response)]
return []
  • 更新历史:
    • 将当前查询和回复以元组形式追加到历史对话列表 history 中。
    • 返回空列表(表示历史对话已更新)。

持续循环

实现命令行界面的主循环,用于接收用户输入、处理特定命令,并调用 predict_and_print 函数生成回复。

history = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
    try:
        query = input("\nUser: ")
    except UnicodeDecodeError:
        print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
        continue
    except Exception:
        raise
    if query.strip() == "exit":
        break
    if query.strip() == "clear":
        history = []
        print("History has been removed.")
        continue
    history = predict_and_print(query, [])

总代码

# coding=utf-8
# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint


from utils import (
    Template,
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)
from threading import Thread
from transformers import TextIteratorStreamer


def main():

    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    prompt_template = Template(data_args.prompt_template)
    source_prefix = data_args.source_prefix if data_args.source_prefix else ""

    def predict_and_print(query, history: list) -> list:
        input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
        input_ids = input_ids.to(model.device)

        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = generating_args.to_dict()
        # gen_kwargs.update({
        #     "input_ids": input_ids,
        #     "logits_processor": get_logits_processor(),
        #     "streamer": streamer
        # })
        gen_kwargs.update({
            "do_sample": True ,
            "num_beams": 1 ,
            "input_ids": input_ids,
            "temperature": 0.5, ##0.1
            "top_p": 0.7, ##0.4
            "top_k": 50,
            "repetition_penalty": 1.2,
            "length_penalty": 1.5,
            "logits_processor": get_logits_processor(),
            "streamer": streamer,
            
        })
        # if request.max_length:
        gen_kwargs.pop("max_new_tokens", None)
        gen_kwargs["max_length"] = 8000

        gen_kwargs.pop("max_length", None)
        gen_kwargs["max_new_tokens"] = 7200

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in streamer:
            print(new_text, end="", flush=True)
            response += new_text
        print()

        history = history + [(query, response)]
        return []

    history = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            history = []
            print("History has been removed.")
            continue

        history = predict_and_print(query, [])


if __name__ == "__main__":
    main()

效果测试

总结(完善)

进行两个总结测试,测试该模型的总结能力还算可以,结果较为正确

在这里插入图片描述
在这里插入图片描述

抽取(完善)

进行两个抽取测试,测试发现该模型的抽取能力水平也较为可以。

在这里插入图片描述

在这里插入图片描述

推理(能力不足)

对其进行推理部分的能力测试,发现其内容并不足以正确的实现推理的能力,仍需要更大更专精的数据集来进行,并且该部分的功能由其他成员完成。

在这里插入图片描述

  • 8
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值