向量数据库最终版修改好的代码
import os import sys import time import tempfile from dotenv import load_dotenv, find_dotenv from embedding.call_embedding import get_embedding from langchain.document_loaders import ( UnstructuredFileLoader, UnstructuredMarkdownLoader, UnstructuredWordDocumentLoader, PyMuPDFLoader ) from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.vectorstores import Chroma # 数据库文件存储路径 DEFAULT_DB_PATH = "../database/knowledge_db" # 存储向量数据库生成的向量数据 DEFAULT_PERSIST_PATH = "../database/vector_data_base" def get_files(dir_path): """ 递归获取指定目录下的所有文件路径。 参数: - dir_path: 目标目录路径。 返回: - file_list: 目录中所有文件的路径列表。 """ file_list = [] for filepath, dirnames, filenames in os.walk(dir_path): for filename in filenames: file_list.append(os.path.join(filepath, filename)) return file_list def file_loader(file, loaders): """ 根据文件类型加载文档,并添加到加载器列表中。 参数: - file: 文件路径或临时文件对象。 - loaders: 文档加载器列表。 """ if isinstance(file, tempfile._TemporaryFileWrapper): file = file.name if not os.path.isfile(file): # 如果是目录,则递归加载目录下的所有文件 [file_loader(os.path.join(file, f), loaders) for f in os.listdir(file)] return #根据文件类型使用不同的方法获取文件 file_type = file.split('.')[-1] if file_type == 'pdf': loaders.append(PyMuPDFLoader(file)) elif file_type == 'md': loaders.append(UnstructuredMarkdownLoader(file)) elif file_type == 'txt': loaders.append(UnstructuredFileLoader(file)) elif file_type == 'docx': loaders.append(UnstructuredWordDocumentLoader(file)) return def create_db_info(files=DEFAULT_DB_PATH, embeddings="zhipuai", persist_directory=DEFAULT_PERSIST_PATH): """ 创建知识库,加载文档并生成嵌入向量数据库。 参数: - files: 要处理的文件路径或文件路径列表。 - embeddings: 用于生成嵌入向量的模型名称或对象。 - persist_directory: 向量数据库的持久化目录路径。 返回: - vectordb: 创建的向量数据库对象。 """ vectordb = create_db(files, persist_directory, embeddings) return vectordb def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="zhipuai"): """ 加载文档,切分文本,生成嵌入向量,创建向量数据库。 参数: - files: 要处理的文件路径或文件路径列表。 - persist_directory: 向量数据库的持久化目录路径。 - embeddings: 用于生成嵌入向量的模型名称或对象。 返回: - vectordb: 创建的向量数据库对象。 """ if files is None: files = DEFAULT_DB_PATH if isinstance(files, str): files = [files] loaders = [] [file_loader(file, loaders) for file in files] docs = [] for loader in loaders: if loader is not None: try: docs.extend(loader.load()) except Exception as e: print(f"Failed to load document: {e}") # 切分文档 text_splitter = RecursiveCharacterTextSplitter( chunk_size=500, chunk_overlap=150) split_docs = text_splitter.split_documents(docs[:10]) # 定义持久化路径 if isinstance(persist_directory, str): persist_directory = [persist_directory] for directory in persist_directory: if not os.path.exists(directory): os.makedirs(directory) if isinstance(embeddings, str): embeddings = get_embedding(embedding=embeddings) # 加载数据库 vectordb = Chroma.from_documents( documents=split_docs, embedding=embeddings, persist_directory=persist_directory ) vectordb.persist() return vectordb def presit_knowledge_db(vectordb): """ 将向量数据库持久化保存。 参数: - vectordb: 要持久化的向量数据库对象。 """ vectordb.persist() def load_knowledge_db(path, embeddings): """ 加载指定路径的向量数据库。 参数: - path: 向量数据库的路径。 - embeddings: 使用的嵌入向量模型。 返回: - vectordb: 加载的向量数据库对象。 """ vectordb = Chroma( persist_directory=path, embedding_function=embeddings ) return vectordb if __name__ == "__main__": create_db(embeddings="zhipuai")