构建聊天机器人的预处理和后处理机制
在构建高效的聊天机器人时,对输入的预处理和输出的后处理至关重要。这篇博客展示如何通过继承GeneratorBase
类,实现对话历史的整合以及对模型输出的精细处理。
功能概述
主要包含三个功能:对输入的预处理、对输出的后处理以及实际的聊天生成函数。这个类是对话系统中的核心,负责处理用户输入,调用深度学习模型生成回复,并对模型的原始输出进行适当的格式化。
输入预处理(preprocess_inputs
)
预处理步骤的主要目的是构建适合模型理解的输入格式。在聊天机器人的上下文中,这通常意味着需要将当前的查询(用户输入)和对话历史整合成一个完整的提示字符串。
def preprocess_inputs(self, query: str, history=None):
if history is None:
history = []
prompt = ""
for record in history:
prompt += f"""<s>:{record[0]}<eoh>\n:{record[1]}<eoa>\n"""
if len(prompt) == 0:
prompt += "<s>"
prompt += f""":{query}<eoh>\n:"""
return prompt, history
在这段代码中,history
包含了一系列的对话记录,每一条记录是一个包含用户输入和系统回复的元组。这些记录被格式化为特定的标记(如<s>
,<eoh>
,<eoa>
)之间的字符串,这些标记帮助模型区分不同部分的内容。这样的格式化对于保持对话上下文的连贯性至关重要。
输出后处理(post_process
)
一旦模型生成了响应,就需要对这些输出进行处理,以便它们可以被用户理解和接受。
def post_process(self, outputs, prompt_length, output_scores=False):
if output_scores:
score = outputs.scores[0]
return score
outputs = outputs.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split("<eoa>")[0]
return response
这里的后处理步骤首先检查是否需要输出得分(用于调试或模型评估),然后从模型输出中剪切掉用于生成的原始提示部分。之后,使用分词器将输出的令牌转换回文本格式,并移除任何特殊标记。
聊天函数(chat
)
聊天函数是整个系统的中枢,它将预处理和后处理串联起来,并管理与深度学习模型的交互。
@torch.no_grad()
def chat(self, query: str, history: List[Tuple[str, str]] = None, eos_token_id = (2, 103028), **kwargs):
prompt, history = self.preprocess_inputs(query, history)
inputs = self.build_tokens(prompt)
output_scores = kwargs.get('output_scores', False)
outputs = self.model.generate(**inputs, eos_token_id=eos_token_id, **kwargs)
prompt_length = len(inputs["input_ids"][0])
response = self.post_process(outputs, prompt_length, output_scores)
return response, history
这个函数首先调用preprocess_inputs
来准备输入,然后使用模型的generate
方法生成文本。输出通过post_process
处理后返回,与更新的历史记录一起返回给调用者。