llama模型集成-week6-2

构建聊天机器人的预处理和后处理机制

在构建高效的聊天机器人时,对输入的预处理和输出的后处理至关重要。这篇博客展示如何通过继承GeneratorBase类,实现对话历史的整合以及对模型输出的精细处理。

功能概述

主要包含三个功能:对输入的预处理、对输出的后处理以及实际的聊天生成函数。这个类是对话系统中的核心,负责处理用户输入,调用深度学习模型生成回复,并对模型的原始输出进行适当的格式化。

输入预处理(preprocess_inputs

预处理步骤的主要目的是构建适合模型理解的输入格式。在聊天机器人的上下文中,这通常意味着需要将当前的查询(用户输入)和对话历史整合成一个完整的提示字符串。

def preprocess_inputs(self, query: str, history=None):
    if history is None:
        history = []
    prompt = ""
    for record in history:
        prompt += f"""<s>:{record[0]}<eoh>\n:{record[1]}<eoa>\n"""
    if len(prompt) == 0:
        prompt += "<s>"
    prompt += f""":{query}<eoh>\n:"""
    return prompt, history

在这段代码中,history包含了一系列的对话记录,每一条记录是一个包含用户输入和系统回复的元组。这些记录被格式化为特定的标记(如<s><eoh><eoa>)之间的字符串,这些标记帮助模型区分不同部分的内容。这样的格式化对于保持对话上下文的连贯性至关重要。

输出后处理(post_process

一旦模型生成了响应,就需要对这些输出进行处理,以便它们可以被用户理解和接受。

def post_process(self, outputs, prompt_length, output_scores=False):
    if output_scores:
        score = outputs.scores[0]
        return score
    outputs = outputs.tolist()[0][prompt_length:]
    response = self.tokenizer.decode(outputs, skip_special_tokens=True)
    response = response.split("<eoa>")[0]
    return response

这里的后处理步骤首先检查是否需要输出得分(用于调试或模型评估),然后从模型输出中剪切掉用于生成的原始提示部分。之后,使用分词器将输出的令牌转换回文本格式,并移除任何特殊标记。

聊天函数(chat

聊天函数是整个系统的中枢,它将预处理和后处理串联起来,并管理与深度学习模型的交互。

@torch.no_grad()
def chat(self, query: str, history: List[Tuple[str, str]] = None, eos_token_id = (2, 103028), **kwargs):
    prompt, history = self.preprocess_inputs(query, history)
    inputs = self.build_tokens(prompt)
    output_scores = kwargs.get('output_scores', False)
    outputs = self.model.generate(**inputs, eos_token_id=eos_token_id, **kwargs)
    prompt_length = len(inputs["input_ids"][0])
    response = self.post_process(outputs, prompt_length, output_scores)
    return response, history

这个函数首先调用preprocess_inputs来准备输入,然后使用模型的generate方法生成文本。输出通过post_process处理后返回,与更新的历史记录一起返回给调用者。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值