为了仿照上述代码实现一个简化的FAISS存储和检索功能,利用LangChain中的 FAISS
和 Document
模块,下面是一个优化后的代码示例。该实现专注于FAISS的核心功能,包括将文档存储到FAISS向量存储中,并在检索时打印所有存储向量和查询的相似度得分。
simple_faiss_service_langchain.py
实现
import os
from langchain.vectorstores.faiss import FAISS
from langchain.schema import Document
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch
class SimpleFaissService:
def __init__(self, embedding_model: str = "m3e-base", distance_metric: str = "ip", storage_path: str = "./faiss_index"):
"""
初始化FAISS存储和检索功能。
:param embedding_model: 嵌入模型名称,例如 "m3e-base"。
:param distance_metric: 距离度量方式,支持 "ip"(内积)或 "l2"(欧氏距离)。
:param storage_path: FAISS索引的保存路径。
"""
self.tokenizer = AutoTokenizer.from_pretrained(embedding_model)
self.model = AutoModel.from_pretrained(embedding_model)
self.storage_path = storage_path
self.distance_metric = distance_metric
# 使用LangChain的FAISS模块
self.vector_store = FAISS.from_documents([], self.embed_text, normalize_L2=(distance_metric == "l2"))
def embed_text(self, texts: list) -> np.ndarray:
"""
将文本列表转换为嵌入向量。
:param texts: 文本列表。
:return: 嵌入向量数组。
"""
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
outputs = self.model(**inputs)
embeddings = outputs.last_hidden_state[:, 0, :]
embeddings = embeddings.cpu().numpy()
if self.distance_metric == "l2":
faiss.normalize_L2(embeddings)
return embeddings
def add_documents(self, docs: list):
"""
将文档添加到FAISS索引中。
:param docs: 文档列表。
"""
documents = [Document(page_content=doc) for doc in docs]
self.vector_store.add_documents(documents)
def search(self, query: str, top_k: int = 5):
"""
在FAISS索引中检索与查询最匹配的文档,并打印每个文档的相似度得分。
:param query: 查询文本。
:param top_k: 返回的匹配文档数量。
:return: 符合条件的匹配文档和相似度得分的列表。
"""
print(f"\n查询内容: {query}")
query_embedding = self.embed_text([query])
results = self.vector_store.similarity_search_with_score_by_vector(query_embedding[0], k=top_k)
print("\n所有文档的相似度得分:")
for doc, score in results:
print(f"文档: {doc.page_content}, 相似度得分: {score}")
return results
def save_index(self):
"""
保存FAISS索引到文件。
"""
if not os.path.exists(self.storage_path):
os.makedirs(self.storage_path)
self.vector_store.save_local(self.storage_path)
print(f"FAISS索引已保存到: {self.storage_path}")
def load_index(self):
"""
从文件加载FAISS索引。
"""
if os.path.exists(self.storage_path):
self.vector_store = FAISS.load_local(self.storage_path, self.embed_text)
print(f"FAISS索引已从 {self.storage_path} 加载")
else:
print("索引文件不存在,无法加载!")
# 使用示例
if __name__ == "__main__":
docs = ["你好,这是一个样本文档。", "今天的天气很好。", "这是关于AI的介绍。", "机器学习是人工智能的一个分支。"]
query = "什么是人工智能?"
# 初始化FAISS服务
faiss_service = SimpleFaissService(embedding_model="m3e-base", distance_metric="ip", storage_path="./faiss_index")
# 添加文档
faiss_service.add_documents(docs)
# 查询并打印相似度得分
results = faiss_service.search(query, top_k=3)
# 保存索引
faiss_service.save_index()
# 加载索引
# faiss_service.load_index()
代码说明
embed_text
:将文本转为嵌入向量,支持指定模型,返回numpy数组。add_documents
:将文档内容转为Document
对象并添加到FAISS索引中。search
:根据查询文本获取最相似的文档,并打印出所有文档的相似度得分。返回与查询匹配的top_k
文档及其得分。save_index
和load_index
:支持将FAISS索引保存到本地,并能从本地加载。
使用说明
- 添加文档:使用
add_documents
添加文档到FAISS索引。 - 查询检索:使用
search
根据查询文本检索匹配文档,打印每个存储文档与查询的相似度得分。 - 保存和加载索引:
save_index
将索引保存到指定文件夹;load_index
从该文件夹加载索引。
此实现基于LangChain的 FAISS
,以简化FAISS存储和检索,并实现打印所有数据的相似度分数。
knowlege_base_chat
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
EMBEDDING_MODEL)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from langchain.docstore.document import Document
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(
SCORE_THRESHOLD,
description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
ge=0,
le=2
),
history: List[History] = Body(
[],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
embed_model: str = Body(EMBEDDING_MODEL, description="嵌入模型名称"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
None,
description="限制LLM生成Token数量,默认None代表模型最大值"
),
prompt_name: str = Body(
"default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
return_all_vectors: bool = Body(True, description="是否返回所有向量数据"),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(
query: str,
top_k: int,
history: Optional[List[History]],
model_name: str = model_name,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
# 使用与知识库向量相同的嵌入方法,获取用户输入问题的向量
embed_func = EmbeddingsFunAdapter(embed_model) # 与知识库的向量方法一致
query_vector = embed_func.embed_query(query)
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
docs_data = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold,
return_all_vectors=True)# 要求返回所有向量
similar_docs = docs_data['similar_docs']
all_vectors = docs_data.get('all_vectors', [])
similarity_scores = [doc.score for doc in similar_docs]
# 加入reranker
if USE_RERANKER:
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
print("-----------------model path------------------")
print(reranker_model_path)
reranker_model = LangchainReranker(top_n=top_k,
device=embedding_device(),
max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path
)
print(similar_docs)
similar_docs = reranker_model.compress_documents(documents=similar_docs, query=query)
print("---------after rerank------------------")
print(similar_docs)
context = "\n".join([doc.page_content for doc in similar_docs])
if len(similar_docs) == 0: # 如果没有找到相关文档,使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
# 将 Document 对象转换为可序列化的字典
source_documents = [
{
"content": doc.page_content, # 提取文档内容
"metadata": doc.metadata, # 提取文档元数据
"source": f"[出处 {inum + 1}] [{doc.metadata.get('source', '未知来源')}]({request.base_url}knowledge_base/download_doc?{urlencode({'knowledge_base_name': knowledge_base_name, 'file_name': doc.metadata.get('source', '未知来源')})})"
}
for inum, doc in enumerate(similar_docs) if isinstance(doc, Document) # 确保只处理 Document 对象
]
# 处理当没有相关文档的情况
if len(source_documents) == 0:
source_documents.append(
{"content": "<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>"})
visualization_data = {
"query_vector": query_vector if isinstance(query_vector, list) else query_vector.tolist(), # 转换向量为列表
"all_vectors": [vector.tolist() if hasattr(vector, 'tolist') else vector for vector in all_vectors],
# 确保所有向量转为列表
"similarity_scores": similarity_scores
}
# 输出结果
if stream:
async for token in callback.aiter():
# 流式输出响应
yield json.dumps({"answer": token}, ensure_ascii=False)
yield json.dumps({"docs": source_documents, "visualization_data": visualization_data}, ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": answer,
"docs": source_documents,
"similarity_scores": similarity_scores,
"visualization_data": visualization_data
}, ensure_ascii=False)
await task
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))