对向量数据库的代码进行了优化和调整
1. 模块导入优化
改进前:
import os
import sys
import time
import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.vectorstores import Chroma
改进后
import os
import sys
import time
import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import (
UnstructuredFileLoader, UnstructuredMarkdownLoader,
UnstructuredWordDocumentLoader, PyMuPDFLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
- 改进后的代码将所有的模块导入放在了一起,使用了更加整洁和规范的格式。
- 将相同包内的导入语句合并在一起,避免了重复导入的情况,提高了代码的可读性和维护性。
2. 异常处理的增加
改进后的 create_db
函数
def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="zhipuai"): """ 加载文档,切分文本,生成嵌入向量,创建向量数据库。 参数: - files: 要处理的文件路径或文件路径列表。 - persist_directory: 向量数据库的持久化目录路径。 - embeddings: 用于生成嵌入向量的模型名称或对象。 返回: - vectordb: 创建的向量数据库对象。 """ if files is None: files = DEFAULT_DB_PATH if isinstance(files, str): files = [files] loaders = [] [file_loader(file, loaders) for file in files] docs = [] for loader in loaders: if loader is not None: try: docs.extend(loader.load()) except Exception as e: print(f"Failed to load document: {e}") # 切分文档 text_splitter = RecursiveCharacterTextSplitter( chunk_size=500, chunk_overlap=150) split_docs = text_splitter.split_documents(docs[:10]) # 定义持久化路径 if isinstance(persist_directory, str): persist_directory = [persist_directory] for directory in persist_directory: if not os.path.exists(directory): os.makedirs(directory) if isinstance(embeddings, str): embeddings = get_embedding(embedding=embeddings) # 加载数据库 vectordb = Chroma.from_documents( documents=split_docs, embedding=embeddings, persist_directory=persist_directory ) vectordb.persist() return vectordb
- 在
create_db
函数中,添加了异常处理机制,用try-except
块捕获加载文档时可能出现的异常。 - 如果某个文档加载失败,程序会打印出错信息并继续处理其他文档,而不是直接中断程序执行。
3. 路径处理的优化
改进后的 create_db
函数
if isinstance(persist_directory, str):
persist_directory = [persist_directory]
for directory in persist_directory:
if not os.path.exists(directory):
os.makedirs(directory)
- 对
persist_directory
的处理进行了优化,确保即使传入单个字符串路径,也能正确处理为列表形式。 - 在创建向量数据库之前,先检查目录是否存在,如果不存在则创建,确保向量数据库的持久化路径正确设置。