from langchain_community.vectorstores import Chroma # 从 langchain_community 导入 Chroma
from langchain.embeddings.base import Embeddings # 导入 Embeddings 接口
from langchain.chains import RetrievalQA
from langchain.llms import Ollama
import os # 导入 os 模块
import requests # 用于调用 Ollama 的 API
import chromadb # 导入 Chroma 客户端
# 启动 Chroma 服务(默认端口 8000)
# docker run -d -p 8000:8000 --name chroma chromadb/chroma:latest
# 使用 Ollama 的 nomic-embed-text 生成嵌入
class NomicEmbedText(Embeddings): # 实现 Embeddings 接口
def __init__(self, base_url="http://127.0.0.1:11434"):
self.base_url = base_url
def embed_query(self, text):
# 调用 Ollama 的 API 生成嵌入
response = requests.post(
f"{self.base_url}/api/embeddings",
json={"model": "nomic-embed-text:v1.5", "prompt": text}
)
if response.status_code != 200:
raise ValueError(f"嵌入生成失败: {response.text}")
return response.json()["embedding"]
def embed_documents(self, texts):
embeddings = []
for text in texts:
embedding = self.embed_query(text)
embeddings.append(embedding)
return embeddings
# 初始化 Chroma 客户端
def init_chroma_client(url="http://localhost:8000"):
client = chromadb.HttpClient(url)
return client
# 加载 Chroma 向量存储
def load_chroma_vector_store(collection_name="CompanyLaw"):
# 使用 nomic-embed-text 嵌入模型
embeddings = NomicEmbedText()
# 连接 Chroma 服务
client = init_chroma_client()
vector_store = Chroma(
client=client, # Chroma 客户端
collection_name=collection_name, # 集合名称
embedding_function=embeddings # 传入 Embeddings 对象
)
return vector_store
# 设置 RAG 管道
def setup_rag_pipeline(vector_store):
# 初始化 Ollama 模型,指定自定义地址和端口
llm = Ollama(
model="deepseek-r1:1.5b", # 模型名称
base_url="http://127.0.0.1:11434" # Ollama 服务的地址和端口
)
# 创建 RAG 管道
qa_pipeline = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff", # 简单拼接检索到的文档块
retriever=vector_store.as_retriever(search_kwargs={"k": 3}) # 检索前 3 个相关块
)
return qa_pipeline
# 查询 RAG 管道
def query_rag_pipeline(qa_pipeline, question):
result = qa_pipeline({"query": question})
return result["result"]
# 主函数
def main():
# 1. 加载 Chroma 向量存储
vector_store = load_chroma_vector_store()
# 2. 设置 RAG 管道
qa_pipeline = setup_rag_pipeline(vector_store)
# 3. 查询
question = "关于公司法律条款的解释"
answer = query_rag_pipeline(qa_pipeline, question)
print("回答:", answer)
if __name__ == "__main__":
main()
12-chroma_nomic_embed_text_rag
最新推荐文章于 2025-05-23 09:16:47 发布