langchain实现知识库

为了仿照上述代码实现一个简化的FAISS存储和检索功能,利用LangChain中的 FAISSDocument 模块,下面是一个优化后的代码示例。该实现专注于FAISS的核心功能,包括将文档存储到FAISS向量存储中,并在检索时打印所有存储向量和查询的相似度得分。

simple_faiss_service_langchain.py 实现

import os
from langchain.vectorstores.faiss import FAISS
from langchain.schema import Document
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch

class SimpleFaissService:
    def __init__(self, embedding_model: str = "m3e-base", distance_metric: str = "ip", storage_path: str = "./faiss_index"):
        """
        初始化FAISS存储和检索功能。
        
        :param embedding_model: 嵌入模型名称,例如 "m3e-base"。
        :param distance_metric: 距离度量方式,支持 "ip"(内积)或 "l2"(欧氏距离)。
        :param storage_path: FAISS索引的保存路径。
        """
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_model)
        self.model = AutoModel.from_pretrained(embedding_model)
        self.storage_path = storage_path
        self.distance_metric = distance_metric

        # 使用LangChain的FAISS模块
        self.vector_store = FAISS.from_documents([], self.embed_text, normalize_L2=(distance_metric == "l2"))

    def embed_text(self, texts: list) -> np.ndarray:
        """
        将文本列表转换为嵌入向量。
        
        :param texts: 文本列表。
        :return: 嵌入向量数组。
        """
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = outputs.last_hidden_state[:, 0, :]
            embeddings = embeddings.cpu().numpy()
        
        if self.distance_metric == "l2":
            faiss.normalize_L2(embeddings)
        return embeddings

    def add_documents(self, docs: list):
        """
        将文档添加到FAISS索引中。
        
        :param docs: 文档列表。
        """
        documents = [Document(page_content=doc) for doc in docs]
        self.vector_store.add_documents(documents)

    def search(self, query: str, top_k: int = 5):
        """
        在FAISS索引中检索与查询最匹配的文档,并打印每个文档的相似度得分。
        
        :param query: 查询文本。
        :param top_k: 返回的匹配文档数量。
        :return: 符合条件的匹配文档和相似度得分的列表。
        """
        print(f"\n查询内容: {query}")
        query_embedding = self.embed_text([query])
        results = self.vector_store.similarity_search_with_score_by_vector(query_embedding[0], k=top_k)

        print("\n所有文档的相似度得分:")
        for doc, score in results:
            print(f"文档: {doc.page_content}, 相似度得分: {score}")

        return results

    def save_index(self):
        """
        保存FAISS索引到文件。
        """
        if not os.path.exists(self.storage_path):
            os.makedirs(self.storage_path)
        self.vector_store.save_local(self.storage_path)
        print(f"FAISS索引已保存到: {self.storage_path}")

    def load_index(self):
        """
        从文件加载FAISS索引。
        """
        if os.path.exists(self.storage_path):
            self.vector_store = FAISS.load_local(self.storage_path, self.embed_text)
            print(f"FAISS索引已从 {self.storage_path} 加载")
        else:
            print("索引文件不存在,无法加载!")

# 使用示例
if __name__ == "__main__":
    docs = ["你好,这是一个样本文档。", "今天的天气很好。", "这是关于AI的介绍。", "机器学习是人工智能的一个分支。"]
    query = "什么是人工智能?"

    # 初始化FAISS服务
    faiss_service = SimpleFaissService(embedding_model="m3e-base", distance_metric="ip", storage_path="./faiss_index")

    # 添加文档
    faiss_service.add_documents(docs)

    # 查询并打印相似度得分
    results = faiss_service.search(query, top_k=3)

    # 保存索引
    faiss_service.save_index()

    # 加载索引
    # faiss_service.load_index()

代码说明

  • embed_text:将文本转为嵌入向量,支持指定模型,返回numpy数组。
  • add_documents:将文档内容转为 Document 对象并添加到FAISS索引中。
  • search:根据查询文本获取最相似的文档,并打印出所有文档的相似度得分。返回与查询匹配的 top_k 文档及其得分。
  • save_indexload_index:支持将FAISS索引保存到本地,并能从本地加载。

使用说明

  1. 添加文档:使用 add_documents 添加文档到FAISS索引。
  2. 查询检索:使用 search 根据查询文本检索匹配文档,打印每个存储文档与查询的相似度得分。
  3. 保存和加载索引save_index 将索引保存到指定文件夹;load_index 从该文件夹加载索引。

此实现基于LangChain的 FAISS,以简化FAISS存储和检索,并实现打印所有数据的相似度分数。

knowlege_base_chat

from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
                     VECTOR_SEARCH_TOP_K,
                     SCORE_THRESHOLD,
                     TEMPERATURE,
                     USE_RERANKER,
                     RERANKER_MODEL,
                     RERANKER_MAX_LENGTH,
                     MODEL_PATH,
                     EMBEDDING_MODEL)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from langchain.docstore.document import Document

async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                              score_threshold: float = Body(
                                  SCORE_THRESHOLD,
                                  description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
                                  ge=0,
                                  le=2
                              ),
                              history: List[History] = Body(
                                  [],
                                  description="历史对话",
                                  examples=[[
                                      {"role": "user",
                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                      {"role": "assistant",
                                       "content": "虎头虎脑"}]]
                              ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              embed_model: str = Body(EMBEDDING_MODEL, description="嵌入模型名称"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: Optional[int] = Body(
                                  None,
                                  description="限制LLM生成Token数量,默认None代表模型最大值"
                              ),
                              prompt_name: str = Body(
                                  "default",
                                  description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                              ),
                              return_all_vectors: bool = Body(True, description="是否返回所有向量数据"),
                              request: Request = None,
                              ):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        # 使用与知识库向量相同的嵌入方法,获取用户输入问题的向量
        embed_func = EmbeddingsFunAdapter(embed_model)  # 与知识库的向量方法一致
        query_vector = embed_func.embed_query(query)

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        docs_data  = await run_in_threadpool(search_docs,
                                       query=query,
                                       knowledge_base_name=knowledge_base_name,
                                       top_k=top_k,
                                       score_threshold=score_threshold,
                                       return_all_vectors=True)# 要求返回所有向量
        similar_docs = docs_data['similar_docs']
        all_vectors = docs_data.get('all_vectors', [])
        similarity_scores = [doc.score for doc in similar_docs]

        # 加入reranker
        if USE_RERANKER:
            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
            print("-----------------model path------------------")
            print(reranker_model_path)
            reranker_model = LangchainReranker(top_n=top_k,
                                            device=embedding_device(),
                                            max_length=RERANKER_MAX_LENGTH,
                                            model_name_or_path=reranker_model_path
                                            )
            print(similar_docs)
            similar_docs = reranker_model.compress_documents(documents=similar_docs, query=query)
            print("---------after rerank------------------")
            print(similar_docs)

        context = "\n".join([doc.page_content for doc in similar_docs])

        if len(similar_docs) == 0:  # 如果没有找到相关文档,使用empty模板
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        # 将 Document 对象转换为可序列化的字典
        source_documents = [
            {
                "content": doc.page_content,  # 提取文档内容
                "metadata": doc.metadata,  # 提取文档元数据
                "source": f"[出处 {inum + 1}] [{doc.metadata.get('source', '未知来源')}]({request.base_url}knowledge_base/download_doc?{urlencode({'knowledge_base_name': knowledge_base_name, 'file_name': doc.metadata.get('source', '未知来源')})})"
            }
            for inum, doc in enumerate(similar_docs) if isinstance(doc, Document)  # 确保只处理 Document 对象
        ]

        # 处理当没有相关文档的情况
        if len(source_documents) == 0:
            source_documents.append(
                {"content": "<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>"})

        visualization_data = {
            "query_vector": query_vector if isinstance(query_vector, list) else query_vector.tolist(),  # 转换向量为列表
            "all_vectors": [vector.tolist() if hasattr(vector, 'tolist') else vector for vector in all_vectors],
            # 确保所有向量转为列表
            "similarity_scores": similarity_scores
        }

        # 输出结果
        if stream:
            async for token in callback.aiter():
                # 流式输出响应
                yield json.dumps({"answer": token}, ensure_ascii=False)
            yield json.dumps({"docs": source_documents, "visualization_data": visualization_data}, ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield json.dumps({"answer": answer,
                              "docs": source_documents,
                              "similarity_scores": similarity_scores,
                              "visualization_data": visualization_data
                              }, ensure_ascii=False)
        await task

    return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值