【LangChain: a fairly complete script for creating a vector database】

- Supports two vector stores, FAISS and Chroma (faiss-cpu supports merging with an existing database)

- Avoids re-embedding files that were already processed (via content hashes)

- Supports many file formats

- Supports HuggingFace embedding models

- Optimized chunk-splitting strategy

- Loads files in parallel with a multiprocessing pool

Adapted from langchain-chatchat, with a few added features and an improved splitter strategy. The full script follows, with a short retrieval sketch at the end.

import os
import glob
import hashlib
from typing import List
from functools import partial
from tqdm import tqdm

from multiprocessing import Pool
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.docstore.document import Document




CONFIG = {
    "doc_source": "./docs",  # directory of documents to vectorize
    "embedding_model": "hugging_models/text2vec-base-chinese_config",  # embedding model
    "db_source": "./db",  # where the vector database is stored
    "db_type": "faiss",  # "faiss" or "chroma"
    "chunk_size": 200,  # characters per chunk
    "chunk_overlap": 20,  # overlap between adjacent chunks
    "k": 3,  # number of documents to retrieve per query
    "merge_rows": 5,  # number of table rows merged into one document
    "hash_file_path": "hash_file.txt",  # record of already-embedded file hashes
}


# Basic directory and embedding settings
source_directory = CONFIG["doc_source"]
embeddings_model_name = CONFIG["embedding_model"]
chunk_size = CONFIG["chunk_size"]
chunk_overlap = CONFIG["chunk_overlap"]
output_dir = CONFIG["db_source"]
k = CONFIG["k"]
merge_rows = CONFIG["merge_rows"]
hash_file_path = CONFIG["hash_file_path"]
db_type = CONFIG["db_type"]


# Custom document loaders
class MyElmLoader(UnstructuredEmailLoader):
    def load(self) -> List[Document]:
        """Wrapper adding fallback for elm without html"""
        try:
            try:
                doc = UnstructuredEmailLoader.load(self)
            except ValueError as e:
                if "text/html content not found in email" in str(e):
                    # Try plain text
                    self.unstructured_kwargs["content_source"] = "text/plain"
                    doc = UnstructuredEmailLoader.load(self)
                else:
                    raise
        except Exception as e:
            # Add file_path to exception message
            raise type(e)(f"{self.file_path}: {e}") from e

        return doc


# Map file extensions to document loaders and their arguments
# Note: Chinese text files may require an encoding such as GB2312 or GB18030 instead of utf8
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".eml": (MyElmLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".xls": (UnstructuredExcelLoader, {}),
    ".xlsx": (UnstructuredExcelLoader, {}),
}


def read_hash_file(path="hash_file.txt"):
    hash_file_list = []
    if os.path.exists(path):
        with open(path, "r") as f:
            hash_file_list = [i.strip() for i in f.readlines()]
    return hash_file_list


def save_hash_file(hash_list, path="hash_file.txt"):
    with open(path, "w") as f:
        f.write("\n".join(hash_list))


def get_hash_from_file(path):
    with open(path, "rb") as f:
        readable_hash = hashlib.md5(f.read()).hexdigest()
    return readable_hash


def load_single_document(
    file_path: str, splitter: TextSplitter, merge_rows: int
) -> List[Document]:
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in LOADER_MAPPING:
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        docs = loader.load()

        # Handle tabular and non-tabular files differently
        if not file_path.endswith((".xlsx", ".xls", ".csv")):
            # Merge all page_content of the file into one document
            tmp = [i.page_content for i in docs]
            docs = Document(
                page_content="".join(tmp).strip(),
                metadata={"source": docs[0].metadata["source"], "pages": len(tmp)},
            )
            # Split into chunks
            docs = splitter.split_documents([docs])
        else:
            # Tabular data: merge every `merge_rows` rows into one document
            merge_n = len(docs) // merge_rows + bool(len(docs) % merge_rows)
            _docs = []
            for i in range(merge_n):
                tmp = "\n\n".join(
                    [
                        d.page_content
                        for d in docs[i * merge_rows : (i + 1) * merge_rows]
                    ]
                )
                _docs.append(
                    Document(
                        page_content=tmp,
                        metadata={"source": docs[0].metadata["source"]},
                    )
                )
            docs = _docs

        return docs

    raise ValueError(f"Unsupported file extension '{ext}'")


def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
    """
    Loads all documents from the source documents directory, ignoring specified files
    """
    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    filtered_files = [
        file_path for file_path in all_files if file_path not in ignored_files
    ]

    # Hash filter: skip files whose content hash has already been embedded
    hash_file_list = read_hash_file(hash_file_path)
    tmp = []
    for file in filtered_files:
        file_hash = get_hash_from_file(file)
        if file_hash not in hash_file_list:
            tmp.append(file)
            hash_file_list.append(file_hash)
    filtered_files = tmp
    save_hash_file(hash_file_list, hash_file_path)

    # splitter
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    load_document = partial(
        load_single_document, splitter=splitter, merge_rows=merge_rows
    )

    # load
    with Pool(processes=os.cpu_count()) as pool:
        results = []
        with tqdm(
            total=len(filtered_files), desc="Loading new documents", ncols=80
        ) as pbar:
            for docs in pool.imap_unordered(load_document, filtered_files):
                results.extend(docs)
                pbar.update()

    return results


def process_documents(ignored_files: List[str] = []) -> List[Document]:
    """
    Load documents and split in chunks
    """
    print(f"Loading documents from {source_directory}")
    documents = load_documents(source_directory, ignored_files)
    if not documents:
        print("No new documents to load")
        exit(0)
    print(
        f"Loaded and split {len(documents)} chunks of text from {source_directory} "
        f"(max. {chunk_size} characters each)"
    )
    return documents


def main():
    # Create the embeddings and store the vector database locally
    print("Creating new vectorstore")
    documents = process_documents()
    print(f"Creating embeddings. May take some minutes...")
    embedding_function = SentenceTransformerEmbeddings(model_name=embeddings_model_name)

    if db_type == "chroma":
        from langchain.vectorstores import Chroma

        db = Chroma.from_documents(
            documents, embedding_function, persist_directory=output_dir
        )
        db.persist()
        db = None
    elif db_type == "faiss":
        from langchain.vectorstores import FAISS

        print("创建新数据库")
        db = FAISS.from_documents(documents, embedding_function)

        # Load the previous db data (merging is not supported by the GPU build of faiss)
        if os.path.exists(output_dir):
            try:
                print("Loading the existing database")
                old_db = FAISS.load_local(output_dir, embedding_function)

                print("融合新旧数据库")
                db.merge_from(old_db)
            except Exception as e:
                print(e)

        print("保存")
        db.save_local(output_dir)
        db = None
    else:
        raise NotImplementedError(f"No implementation defined for database type '{db_type}'.")

if __name__ == "__main__":
    main()
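
The script above only builds and saves the index. As a minimal retrieval sketch (not part of the original code; the query string is a placeholder and the variables simply mirror CONFIG), the saved FAISS index can be loaded with the same embedding model and queried like this:

from langchain.vectorstores import FAISS
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings

# Mirror the values used in CONFIG above
db_source = "./db"
embedding_model = "hugging_models/text2vec-base-chinese_config"
k = 3  # number of documents to retrieve

embedding_function = SentenceTransformerEmbeddings(model_name=embedding_model)
db = FAISS.load_local(db_source, embedding_function)

# Placeholder query; replace with a real question
results = db.similarity_search("example query", k=k)
for doc in results:
    print(doc.metadata.get("source"), doc.page_content[:80])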