langchain-chatchat 0.2.x版本知识库对话模块增加用户反馈功能

当我成功在服务器上部署langchain-chatchat v0.2.9后,我想通过用户给知识库的正确反馈来更新知识库内容从而实现知识库的自动更新优化,但是我发现这个版本好像只对llm对话支持用户反馈功能,而现今该项目的v0.2.x版本已经不再维护更新,于是乎我打算自己动手丰衣足食。

话不多说我们直接开始

首先我们找到项目的前端页面代码,路径为

Langchain_Chatchat_master/webui_pages/dialogue/dialogue.py

可以找到下面的内容

以上图片中的内容对应了llm对话模块的一些详情,很明显feedback_kwargs作为参数被传给了chat_box类里的一个函数show_feedback,而且可以看到该函数下还传递了一个参数kwargs,其中包含了一个叫message_id的键,让我们再看看知识库对话模块的详情,如下:

这两个模块一比较就会发现,在llm中使用了chat_box.show_feedback()这个函数增加了反馈功能,而知识库对话模块中没有使用,而要使用chat_box.show_feedback()进行反馈就需要一个message_id和history_index(这个参数无关紧要),message_id到底是什么呢,我查看了项目下保存的dp数据库文件,由此肯定message_id是用户当前所对应的对话聊天框的编号,因为用户要对某个回答进行反馈时得根据一个唯一的编号找到对应的回答,但这个其实不重要,重要的是我将知识库对话接口的响应内容打印出来后发现并没有message_id这个值,而llm对话接口的响应内容中有这个值(可以打印出api.knowledge_base_chat和api.chat_chat这两个接口的返回结果比对),所以有同学将llm对话模块的内容直接复制替换掉知识库对话模块的内容是行不通的。

我们可以进入下面的文件查看api.knowledge_base_chat和api.chat_chat这两个接口函数是怎么定义的,路径如下:

# 对应知识库对话接口
Langchain_Chatchat_master/server/chat/knowledge_base_chat.py
# 对应llm对话接口
Langchain_Chatchat_master/server/chat/chat.py

chat.py的返回结果如下,可以看到带了message_id

knowledge_base_chat.py的返回结果如下:

并没有看到message_id,因此核心的要点就是通过修改知识库对话接口的返回内容,自己给它加上message_id,再获取它传入chat_box.show_feedback()中即可。

下面我们开始修改接口,直接附上代码,将其替换掉原来的knowledge_base_chat.py内容即可

from fastapi import Body, Request
import sys
sys.path.append('configs')
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, 
                     VECTOR_SEARCH_TOP_K, 
                     SCORE_THRESHOLD, 
                     TEMPERATURE,
                     USE_RERANKER,
                     RERANKER_MODEL,
                     RERANKER_MAX_LENGTH,
                     MODEL_PATH,
                     HISTORY_LEN)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler


async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              conversation_id: str = Body("", description="对话框ID"),
                              knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                              score_threshold: float = Body(
                                  SCORE_THRESHOLD,
                                  description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
                                  ge=0,
                                  le=2
                              ),
                            #   history_len: int = Body(HISTORY_LEN, description="从数据库中取历史消息的数量"),
                              history: List[History] = Body(
                                  [],
                                  description="历史对话",
                                  examples=[[
                                      {"role": "user",
                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                      {"role": "assistant",
                                       "content": "虎头虎脑"}]]
                              ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: Optional[int] = Body(
                                  None,
                                  description="限制LLM生成Token数量,默认None代表模型最大值"
                              ),
                              prompt_name: str = Body(
                                  "default",
                                  description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                              ),
                              request: Request = None,
                              ):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            # history_len: int,
            history: Optional[List[History]],
            conversation_id: str = conversation_id,
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        callbacks = [callback]
        memory = None

        message_id = add_message_to_db(chat_type="知识库对话", query=query, conversation_id=conversation_id)
        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
                                                            chat_type="知识库对话",
                                                            query=query)
        callbacks.append(conversation_callback)        
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=callbacks,
        )
        docs = await run_in_threadpool(search_docs,
                                       query=query,
                                       knowledge_base_name=knowledge_base_name,
                                       top_k=top_k,
                                       score_threshold=score_threshold)

        # 加入reranker
        if USE_RERANKER:
            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"model/bge-reranker-large")
            print("-----------------model path------------------")
            print(reranker_model_path)
            reranker_model = LangchainReranker(top_n=top_k,
                                            device=embedding_device(),
                                            max_length=RERANKER_MAX_LENGTH,
                                            model_name_or_path=reranker_model_path
                                            )
            print(docs)
            docs = reranker_model.compress_documents(documents=docs,
                                                     query=query)
            print("---------after rerank------------------")
            print(docs)
        context = "\n".join([doc.page_content for doc in docs])

        if len(docs) == 0:  # 如果没有找到相关文档,使用empty模板
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])
        chain = LLMChain(prompt=chat_prompt, llm=model)

        # # 如果要使用历史对话
        # if conversation_id and history_len > 0: # 前端要求从数据库取历史消息
        #     # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
        #     prompt = get_prompt_template("knowledge_base_chat", "with_history")
        #     chat_prompt = PromptTemplate.from_template(prompt)
        #     # 根据conversation_id 获取message 列表进而拼凑 memory
        #     memory = ConversationBufferDBMemory(conversation_id=conversation_id,
        #                                         llm=model,
        #                                         message_limit=history_len)
        #     chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
        

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        source_documents = []
        for inum, doc in enumerate(docs):
            filename = doc.metadata.get("source")
            parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
            base_url = request.base_url
            url = f"{base_url}knowledge_base/download_doc?" + parameters
            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)

        if len(source_documents) == 0:  # 没有找到相关文档
            source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps({"answer": token, "message_id":message_id}, ensure_ascii=False)
            yield json.dumps({"docs": source_documents}, ensure_ascii=False)
            # yield json.dumps({"docs": source_documents,"message_id":message_id}, ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield json.dumps({"answer": answer,
                              "docs": source_documents,
                              "message_id":message_id},
                             ensure_ascii=False)
        await task

    return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history, conversation_id, model_name, prompt_name))

if __name__ == '__main__':
    import asyncio

修改完这个文件后,还需要对项目中封装api接口的代码文件进行修改,因为往函数中多加了一些参数,路径如下:

Langchain_Chatchat_master/webui_pages/utils.py

找到这个文件里的知识库接口定义函数,将下面的代码替换掉原来的

def knowledge_base_chat(
    self,
    query: str,
    knowledge_base_name: str,
    conversation_id: str,
    top_k: int = VECTOR_SEARCH_TOP_K,
    score_threshold: float = SCORE_THRESHOLD,
    history: List[Dict] = [],
    stream: bool = True,
    model: str = LLM_MODELS[0],
    temperature: float = TEMPERATURE,
    max_tokens: int = None,
    prompt_name: str = "default",
):
    '''
    对应api.py/chat/knowledge_base_chat接口
    '''
    data = {
        "query": query,
        "conversation_id":conversation_id,
        "knowledge_base_name": knowledge_base_name,
        "top_k": top_k,
        "score_threshold": score_threshold,
        "history": history,
        "stream": stream,
        "model_name": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "prompt_name": prompt_name,
    }

    # print(f"received input message:")
    # pprint(data)

    response = self.post(
        "/chat/knowledge_base_chat",
        json=data,
        stream=True,
    )
    return self._httpx_stream2generator(response, as_json=True)

定义完了新的知识库的返回结果和api接口,最后一步只需要修改前端页面的调用代码即可,回到文章开头提到的前端页面代码下,找到知识库模板,用以下代码替换掉原来的就可以了

elif dialogue_mode == "知识库问答":
    chat_box.ai_say([
        f"正在查询知识库 `{selected_kb}` ...",
        Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
    ])
    text = ""
    message_id = ""
    for d in api.knowledge_base_chat(prompt1,
                                    conversation_id=conversation_id,
                                    knowledge_base_name=selected_kb,
                                    top_k=kb_top_k,
                                    score_threshold=score_threshold,
                                    # history_len=history_len,
                                    history=history,
                                    model=llm_model,
                                    prompt_name=prompt_template_name,
                                    temperature=temperature):
        print(d)
        if error_msg := check_error_msg(d):  # check whether error occured
            st.error(error_msg)
        elif chunk := d.get("answer"):
            text += chunk
            chat_box.update_msg(text, element_index=0)
            message_id = d.get("message_id","")              
    metadata = {
        "message_id": message_id,
        }
    chat_box.update_msg(text, element_index=0, streaming=False)
    chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
    chat_box.show_feedback(**feedback_kwargs,
        key=message_id,
        on_submit=on_feedback,
        kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})

以上步骤完成后,打印知识库问答返回的结果,可以看到包含了message_id,而知识库对话的反馈功能也成功加上了

这个功能是其实两个月前已经实现了,本来当时想写的,但是后来一直在忙其它方向就给忘了,所以以上的一些操作都是根据我当时的部署记录总结的,如果存在什么问题可以在评论区留言,咱们一起探讨。

  • 24
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值