【山东大学项目实训】9

向量数据库最终版修改好的代码

import os
import sys


import time
import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import (
    UnstructuredFileLoader, UnstructuredMarkdownLoader,
    UnstructuredWordDocumentLoader, PyMuPDFLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# 数据库文件存储路径
DEFAULT_DB_PATH = "../database/knowledge_db"

# 存储向量数据库生成的向量数据
DEFAULT_PERSIST_PATH = "../database/vector_data_base"


def get_files(dir_path):
    """
    递归获取指定目录下的所有文件路径。

    参数:
    - dir_path: 目标目录路径。

    返回:
    - file_list: 目录中所有文件的路径列表。
    """
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(filepath, filename))
    return file_list


def file_loader(file, loaders):
    """
    根据文件类型加载文档,并添加到加载器列表中。

    参数:
    - file: 文件路径或临时文件对象。
    - loaders: 文档加载器列表。
    """
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        # 如果是目录,则递归加载目录下的所有文件
        [file_loader(os.path.join(file, f), loaders) for f in os.listdir(file)]
        return

      #根据文件类型使用不同的方法获取文件
    file_type = file.split('.')[-1]
    if file_type == 'pdf':
        loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == 'txt':
        loaders.append(UnstructuredFileLoader(file))
    elif file_type == 'docx':
        loaders.append(UnstructuredWordDocumentLoader(file))
    return


def create_db_info(files=DEFAULT_DB_PATH, embeddings="zhipuai", persist_directory=DEFAULT_PERSIST_PATH):
    """
    创建知识库,加载文档并生成嵌入向量数据库。

    参数:
    - files: 要处理的文件路径或文件路径列表。
    - embeddings: 用于生成嵌入向量的模型名称或对象。
    - persist_directory: 向量数据库的持久化目录路径。

    返回:
    - vectordb: 创建的向量数据库对象。
    """
    vectordb = create_db(files, persist_directory, embeddings)
    return vectordb


def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="zhipuai"):
    """
    加载文档,切分文本,生成嵌入向量,创建向量数据库。

    参数:
    - files: 要处理的文件路径或文件路径列表。
    - persist_directory: 向量数据库的持久化目录路径。
    - embeddings: 用于生成嵌入向量的模型名称或对象。

    返回:
    - vectordb: 创建的向量数据库对象。
    """
    if files is None:
        files = DEFAULT_DB_PATH
    if isinstance(files, str):
        files = [files]
    loaders = []
    [file_loader(file, loaders) for file in files]
    docs = []
    for loader in loaders:
        if loader is not None:
            try:
                docs.extend(loader.load())
            except Exception as e:
                print(f"Failed to load document: {e}")
    # 切分文档
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(docs[:10])

    # 定义持久化路径
    if isinstance(persist_directory, str):
        persist_directory = [persist_directory]
    for directory in persist_directory:
        if not os.path.exists(directory):
            os.makedirs(directory)

    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)

    # 加载数据库
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory
    )
    vectordb.persist()
    return vectordb


def presit_knowledge_db(vectordb):
    """
    将向量数据库持久化保存。

    参数:
    - vectordb: 要持久化的向量数据库对象。
    """
    vectordb.persist()


def load_knowledge_db(path, embeddings):
    """
    加载指定路径的向量数据库。

    参数:
    - path: 向量数据库的路径。
    - embeddings: 使用的嵌入向量模型。

    返回:
    - vectordb: 加载的向量数据库对象。
    """
    vectordb = Chroma(
        persist_directory=path,
        embedding_function=embeddings
    )
    return vectordb


if __name__ == "__main__":
    create_db(embeddings="zhipuai")
  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值