【山东大学项目实训】进度汇报16

进行了带有历史记录的问答链的修改优化

Chat_QA_chain_self

from langchain.chains import ConversationalRetrievalChain
from qa_chain.model_to_llm import model_to_llm
from qa_chain.get_vectordb import get_vectordb

class Chat_QA_chain_self:
    """"
    带历史记录的问答链  
    - model:调用的模型名称
    - temperature:温度系数,控制生成的随机性
    - top_k:返回检索的前k个相似文档
    - chat_history:历史记录,输入一个列表,默认是一个空列表
    - history_len:控制保留的最近 history_len 次对话
    - file_path:建库文件所在路径
    - persist_path:向量数据库持久化路径
    - api_key:智谱都需要传递的参数
    - embeddings:使用的embedding模型
    - embedding_key:使用的embedding模型的秘钥(智谱)
    """

    def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4, chat_history: list = [],
                 file_path: str = None, persist_path: str = None, api_key: str = None, embedding="zhipuai",
                 embedding_key: str = None):
        # 初始化类实例的各个属性
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        self.chat_history = chat_history
        self.file_path = file_path
        self.persist_path = persist_path
        self.api_key = api_key
        self.embedding = embedding
        self.embedding_key = embedding_key

        # 获取向量数据库实例
        self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding, self.embedding_key)

    def clear_history(self):
        "清空历史记录"
        # 清空聊天历史记录
        return self.chat_history.clear()

    def change_history_length(self, history_len: int = 1):
        """
        保存指定对话轮次的历史记录
        输入参数:
        - history_len :控制保留的最近 history_len 次对话
        - chat_history:当前的历史对话记录
        输出:返回最近 history_len 次对话
        """
        # 保留最近的 history_len 次对话
        n = len(self.chat_history)
        return self.chat_history[n - history_len:]

    def answer(self, question: str = None, temperature=None, top_k=4):
        """"
        核心方法,调用问答链
        arguments: 
        - question:用户提问
        """
        # 如果问题为空,返回空字符串和当前聊天历史记录
        if len(question) == 0:
            return "", self.chat_history

        # 如果温度参数未设置,使用实例初始化时的温度
        if temperature == None:
            temperature = self.temperature

        # 获取语言模型实例
        llm = model_to_llm(self.model, temperature, self.api_key)

        # 获取检索器实例,使用向量数据库进行相似性检索
        retriever = self.vectordb.as_retriever(search_type="similarity",
                                               search_kwargs={'k': top_k})  # 默认similarity,k=4

        # 创建问答链实例
        qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever
        )

        # 调用问答链,获取回答
        result = qa({"question": question, "chat_history": self.chat_history})  # result里有question、chat_history、answer
        answer = result['answer']
        
        # 将问题和回答添加到历史记录中
        self.chat_history.append((question, answer))

        # 返回更新后的历史记录
        return self.chat_history


1. 导入模块:
   - `ConversationalRetrievalChain`:用于创建带有检索功能的对话链。
   - `model_to_llm`:将模型名称转换为具体的语言模型实例。
   - `get_vectordb`:获取向量数据库实例,用于检索相似文档。

2. 类定义:
   - `Chat_QA_chain_self`:这是一个自定义的问答链类,包含了与模型交互、历史记录管理、向量检索等功能。

3. 初始化方法 `__init__`:
   - 初始化类的各种参数,包括模型名称、温度、检索文档数量、历史记录、文件路径、向量数据库路径、API 密钥、embedding 模型及其密钥。
   - 调用 `get_vectordb` 获取向量数据库实例。

4. `clear_history` 方法:
   - 清空当前的聊天历史记录。

5. `change_history_length` 方法:
   - 调整历史记录的长度,仅保留最近的 `history_len` 次对话。

6. `answer` 方法:
   - 处理用户提问,调用问答链获取回答,并更新历史记录。
   - 使用 `model_to_llm` 获取语言模型实例。
   - 使用向量数据库实例进行相似性检索。
   - 创建问答链实例并调用,获取回答并更新历史记录。

提供可以记录和管理历史对话的问答系统,支持基于相似性检索的文档查询,从而提高回答的准确性和相关性

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值