进行了问答链的优化
优化说明
- 模块级别的常量定义:将默认模板
DEFAULT_TEMPLATE_RQ
移到类外,便于模块级别管理。 - 封装初始化函数:将向量数据库和大语言模型的初始化代码封装到单独的函数中,提高代码的可读性和复用性。
- 添加异常处理:在初始化和问答方法中添加异常处理,确保在遇到错误时可以提供明确的提示。
- 参数校验:在
answer
方法中增加对 question
参数的校验,确保问题不为空。 - 提高可读性:通过增加注释和文档字符串,解释函数和类的作用,提高代码的可读性和可维护性。
from langchain.prompts import PromptTemplate
# 导入PromptTemplate类,用于创建问题模板
from langchain.chains import RetrievalQA
# 导入RetrievalQA类,用于执行检索问答
from langchain.vectorstores import Chroma
# 导入Chroma类,用于处理向量存储
import sys
# 导入sys模块,用于操作系统相关功能
sys.path.append("../")
# 添加上级目录到模块搜索路径
from qa_chain.model_to_llm import model_to_llm
# 导入model_to_llm函数,用于将模型转换为大语言模型(LLM)
from qa_chain.get_vectordb import get_vectordb
# 导入get_vectordb函数,用于获取向量数据库
#定义 QA_chain_self 类:
#该类用于创建和管理一个问答链系统,不带历史记录,类的描述文档说明了各个参数的用途。
class QA_chain_self():
# 类描述
"""
不带历史记录的问答链
- model:调用的模型名称
- temperature:温度系数,控制生成的随机性
- top_k:返回检索的前k个相似文档
- file_path:建库文件所在路径
- persist_path:向量数据库持久化路径
- api_key:所有模型都需要
- embeddings:使用的embedding模型
- embedding_key:使用的embedding模型的秘钥(智谱或者OpenAI)
- template:可以自定义提示模板,没有输入则使用默认的提示模板default_template_rq
"""
# 默认的提示模板,用于构建问答的输入
default_template_rq = """
使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说“谢谢你的提问!”。
{context}
问题: {question}
有用的回答:"""
# 构造函数,初始化类实例
#初始化向量数据库和LLM
#初始化 PromptTemplate 和 RetrievalQA 类
def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4, file_path: str = None,
persist_path: str = None, api_key: str = None,
embedding="zhipuai", embedding_key=None, template=default_template_rq):
# 类属性初始化
self.model = model
self.temperature = temperature
self.top_k = top_k
self.file_path = file_path
self.persist_path = persist_path
self.api_key = api_key
self.embedding = embedding
self.embedding_key = embedding_key
self.template = template
# 初始化向量数据库和大语言模型
self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding, self.embedding_key)
self.llm = model_to_llm(self.model, self.temperature, self.api_key)
# 初始化PromptTemplate和RetrievalQA类
self.QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=self.template)
self.retriever = self.vectordb.as_retriever(search_type="similarity", search_kwargs={'k': self.top_k})
self.qa_chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever,
return_source_documents=True,
chain_type_kwargs={"prompt": self.QA_CHAIN_PROMPT})
# 提供问答功能的方法,根据用户问题调用问答链并返回结果
def answer(self, question: str = None, temperature=None, top_k=4):
"""
核心方法,调用问答链
arguments:
- question:用户提问
"""
# 检查问题是否为空
if len(question) == 0:
return ""
# 如果没有指定温度或top_k参数,则使用初始化时的值
if temperature is None:
temperature = self.temperature
if top_k is None:
top_k = self.top_k
# 调用问答链并返回结果
result = self.qa_chain({"query": question, "temperature": temperature, "top_k": top_k})
return result["result"]