基于 ChatGLM3 和 LangChain 搭建知识库

简介

LangChain 是一个开源库,专为构建语言模型代理而设计,使开发者能够轻松集成和控制大型语言模型,如 OpenAI 的 GPT。LangChain 提供了一系列工具和框架,支持多种功能,包括但不限于聊天机器人、问答系统和自动化文本生成。

数据源地址

环境配置

在已完成 ChatGLM3 的部署基础上,还需要安装以下依赖包:

pip install langchain==0.0.292
pip install gradio==4.4.0
pip install chromadb==0.4.15
pip install sentence-transformers==2.2.2
pip install unstructured==0.10.30
pip install markdown==3.3.7

知识库搭建

我们选择NLTK作为语料库:

百度网盘链接:百度网盘 请输入提取码 下载完成将下载的nltk_data文件夹拷贝到用户跟目录

一,安装依赖:

virtualenv glm

source glm/bin/activate

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install tpu_perf-1.2.24-py3-none-manylinux2014_aarch64.whl

接着,为语料处理方便,我们将选用上述仓库中所有的 markdown、txt 文件作为示例语料库。注意,也可以选用其中的代码文件加入到知识库中,但需要针对代码文件格式进行额外处理。

二,向量转换代码

我们建立了一个 DocChatbot 类,旨在提供一个完整的框架,用于从文档中提取信息,建立索引,以及通过查询这些索引来回答用户问题。这种方法适用于需要从大量文档中快速检索信息的应用

类定义与构造器

class DocChatbot:
    _instance = None
    db_base_path = "data/db"

    def __init__(self) -> None:
        _instance = None
        self.llm = client
        self.vector_db = None
        self.string_db = None
        self.files = None
        self.embeddings_size = 768
        self.embeddings = HuggingFaceEmbeddings(model_name="./embedding")
        print("chatbot init success!")

构造函数初始化了嵌入模型、嵌入向量数据库、文本数据库和其他相关属性。这为文档处理和查询提供了基础。

文件处理方法

def get_files(self,dir_path: str) -> List[str]:
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            if filename.endswith(".md") or filename.endswith(".txt"):
                file_list.append(os.path.join(filepath, filename))
    return file_list

    我们定义的 get_files 方法用于遍历指定目录 dir_path 及其所有子目录,寻找并收集所有以 .md.txt 结尾的文件。它利用 os.walk 函数递归地访问每一个子目录,检查每个文件名是否符合指定的后缀名。如果符合,就使用 os.path.join 将目录路径和文件名组合成完整的文件路径,并将这个路径添加到 file_list 列表中。方法最终返回这个列表,包含了所有找到的符合条件的文件路径。这种方式使得能够系统地从复杂的文件系统中提取特定类型的文件,为后续的文件处理或数据加载提供基础。

文档到嵌入向量的转换

def docs2embedding(self, docs):
    emb = []
    for i in tqdm(range(len(docs) // 4)):
        emb += self.embeddings.embed_documents(docs[i * 4: i * 4 + 4])
    if len(docs) % 4 != 0:
        residue = docs[-(len(docs) % 4):] + [" " for _ in range(4 - len(docs) % 4)]
        emb += self.embeddings.embed_documents(residue)[:len(docs) % 4]
    return emb

  docs2embedding 方法将一组文档转换为嵌入向量,主要通过分批处理以优化效率和内存使用。首先,它将文档集每四个一组进行批处理,并使用 embed_documents 方法计算这些批次的嵌入向量,然后将这些向量追加到一个列表中。如果文档总数不是四的倍数,该方法将处理剩余的文档,通过添加空字符串以填满最后一批次到四个文档,并计算这些文档的嵌入向量。

查询和索引构建

1文本加载和预处理

在这一步,方法首先配置一个文本分割器以适应不同大小的文档,并为每种文档类型选择合适的加载器来处理文件。这确保了无论文档的格式如何,都能被正确加载并准备好进行下一步处理。

text_splitter = RecursiveCharacterTextSplitter(chunk_size=325, chunk_overlap=6, separators=["\n\n", "\n", "。", "!", ",", " ", ""])
docs = []
for file in file_list:
    ext_name = os.path.splitext(file)[-1]
    if ext_name == ".pptx":
        loader = UnstructuredPowerPointLoader(file)
    elif ext_name == ".docx":
        loader = UnstructuredWordDocumentLoader(file)
    elif ext_name == ".pdf":
        loader = UnstructuredPDFLoader(file)
    else:
        loader = UnstructuredFileLoader(file)
    doc = loader.load()
    doc[0].page_content = self.filter_space(doc[0].page_content)
    doc = text_splitter.split_documents(doc)
    docs.extend(doc)
2文档分割与向量数据库初始化

此阶段包括对加载的文档进行分割,并根据分割后的文档内容初始化或更新向量数据库。如果文档内容为空,则直接返回 False 表示初始化失败;否则,处理文档的嵌入向量,并将它们添加到向量数据库中。

init_vector_db_from_documents 方法在 ChatDoc 类中负责从提供的文件列表初始化或更新向量数据库。此方法首先利用特定的加载器根据文件扩展名加载文档,然后通过 RecursiveCharacterTextSplitter 分割文档内容,并通过 filter_space 清洗文本。分割和清洗后的文档被转换为嵌入向量,这些向量随后被存储到 faiss.IndexFlatL2 类型的向量数据库中。如果数据库尚未创建,则会新建一个;如果已存在,则向其中添加新的向量。这一过程支持多种文档格式,包括PPTX、DOCX和PDF等,旨在为搜索和查询功能提供高效、可扩展的数据支持,从而允许系统动态地处理和索引大量的文档,以便快速检索和问答。

def init_vector_db_from_documents(self, file_list: List[str]):
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=325, chunk_overlap=6,
                                                       separators=["\n\n", "\n", "。", "!", ",", " ", ""])
        docs = []
        for file in file_list:
            ext_name = os.path.splitext(file)[-1]
            if ext_name == ".pptx":
                loader = UnstructuredPowerPointLoader(file)
            elif ext_name == ".docx":
                loader = UnstructuredWordDocumentLoader(file)
            elif ext_name == ".pdf":
                loader = UnstructuredPDFLoader(file)
            else:
                loader = UnstructuredFileLoader(file)

            doc = loader.load()
            doc[0].page_content = self.filter_space(doc[0].page_content)
            doc = text_splitter.split_documents(doc)
            docs.extend(doc)

        # print([(len(x.page_content), count_chinese_chars(x.page_content)) for x in docs])
        # for item in docs:
        #     if len(item.page_content) / count_chinese_chars(item.page_content) > 1.5:
        #         print(len(item.page_content), item.page_content)

        # 文件解析失败
        if len(docs) == 0:
            return False

        if self.vector_db is None:
            self.files = ", ".join([item.split("/")[-1] for item in file_list])
            emb = self.docs2embedding([x.page_content for x in docs])
            self.vector_db = faiss.IndexFlatL2(self.embeddings_size)
            self.vector_db.add(np.array(emb))
            self.string_db = docs
        else:
            self.files = self.files + ", " + ", ".join([item.split("/")[-1] for item in file_list])
            emb = self.docs2embedding([x.page_content for x in docs])
            self.vector_db.add(np.array(emb))
            self.string_db += docs
        return True

保存和加载向量数据库

save_vector_db_to_local 方法

该方法用于将当前聊天机器人的向量数据库及相关数据保存到本地文件系统中。

def save_vector_db_to_local(self):
    # 获取当前时间并格式化为特定字符串格式,用于创建唯一文件夹名
    now = datetime.now()
    folder_name = now.strftime("%Y-%m-%d_%H-%M-%S-%f")
    os.mkdir(f"{self.db_base_path}/{folder_name}")  # 在指定的基路径下创建新文件夹

    # 保存向量索引文件
    faiss.write_index(self.vector_db, f"{self.db_base_path}/{folder_name}/db.index")

    # 将存储文本数据的string_db序列化并保存到文件
    byte_stream = pickle.dumps(self.string_db)
    with open(f"{self.db_base_path}/{folder_name}/db.string", "wb") as file:
        file.write(byte_stream)

    # 保存包含处理文件信息的文本文件
    with open(f"{self.db_base_path}/{folder_name}/name.txt", "w", encoding="utf-8") as file:
        file.write(self.files)
load_vector_db_from_local 方法

该方法用于从本地存储加载之前保存的向量数据库及其相关数据。

def load_vector_db_from_local(self, index_name: str):
    # 从文件中读取并反序列化文本数据
    with open(f"{self.db_base_path}/{index_name}/db.string", "rb") as file:
        byte_stream = file.read()
    self.string_db = pickle.loads(byte_stream)

    # 加载向量索引
    self.vector_db = faiss.read_index(f"{self.db_base_path}/{index_name}/db.index")

    # 读取包含处理文件信息的文本文件
    self.files = open(f"{self.db_base_path}/{index_name}/name.txt", 'r', encoding='utf-8').read()

整体代码:

# coding=utf-8

import os
import shutil
import time
import numpy as np
from datetime import datetime
import faiss
from langchain.document_loaders import UnstructuredPowerPointLoader, UnstructuredWordDocumentLoader, \
    UnstructuredPDFLoader, UnstructuredFileLoader
import logging
import pickle
from langchain.embeddings import HuggingFaceEmbeddings
# from embedding_tpu.embedding import Word2VecEmbedding
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List
from glob import glob
from tqdm import tqdm

from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role

client = get_client()

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)


class ChatDoc:
    _instance = None
    db_base_path = "data/db"

    def __init__(self) -> None:
        _instance = None
        self.llm = client
        self.vector_db = None
        self.string_db = None
        self.files = None

        # embeding_tpye = os.getenv("EMBEDDING_TYPE")
        # if embeding_tpye == "cpu":
        self.embeddings_size = 768
        self.embeddings = HuggingFaceEmbeddings(model_name="./embedding")
        # else:
        #     self.embeddings_size = 1024
        #     self.embeddings = Word2VecEmbedding()
        print("chatbot init success!")
        
    def get_files(self,dir_path: str) -> List[str]:
        file_list = []
        for filepath, dirnames, filenames in os.walk(dir_path):
            for filename in filenames:
                if filename.endswith(".md") or filename.endswith(".txt"):
                    file_list.append(os.path.join(filepath, filename))
        return file_list

    def docs2embedding(self, docs):
        emb = []
        for i in tqdm(range(len(docs) // 4)):
            emb += self.embeddings.embed_documents(docs[i * 4: i * 4 + 4])
        if len(docs) % 4 != 0:
            residue = docs[-(len(docs) % 4):] + [" " for _ in range(4 - len(docs) % 4)]
            emb += self.embeddings.embed_documents(residue)[:len(docs) % 4]

        return emb

    def query_from_doc(self, query_string, k=1):
        query_vec = self.embeddings.embed_query(query_string)
        _, i = self.vector_db.search(x=np.array([query_vec]), k=k)
        return [self.string_db[ind] for ind in i[0]]

    # split documents, generate embeddings and ingest to vector db
    def init_vector_db_from_documents(self, file_list: List[str]):
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=325, chunk_overlap=6,
                                                       separators=["\n\n", "\n", "。", "!", ",", " ", ""])
        docs = []
        for file in file_list:
            ext_name = os.path.splitext(file)[-1]
            if ext_name == ".pptx":
                loader = UnstructuredPowerPointLoader(file)
            elif ext_name == ".docx":
                loader = UnstructuredWordDocumentLoader(file)
            elif ext_name == ".pdf":
                loader = UnstructuredPDFLoader(file)
            else:
                loader = UnstructuredFileLoader(file)

            doc = loader.load()
            doc[0].page_content = self.filter_space(doc[0].page_content)
            doc = text_splitter.split_documents(doc)
            docs.extend(doc)

        # print([(len(x.page_content), count_chinese_chars(x.page_content)) for x in docs])
        # for item in docs:
        #     if len(item.page_content) / count_chinese_chars(item.page_content) > 1.5:
        #         print(len(item.page_content), item.page_content)

        # 文件解析失败
        if len(docs) == 0:
            return False

        if self.vector_db is None:
            self.files = ", ".join([item.split("/")[-1] for item in file_list])
            emb = self.docs2embedding([x.page_content for x in docs])
            self.vector_db = faiss.IndexFlatL2(self.embeddings_size)
            self.vector_db.add(np.array(emb))
            self.string_db = docs
        else:
            self.files = self.files + ", " + ", ".join([item.split("/")[-1] for item in file_list])
            emb = self.docs2embedding([x.page_content for x in docs])
            self.vector_db.add(np.array(emb))
            self.string_db += docs
        return True

    def load_vector_db_from_local(self, index_name: str):
        with open(f"{self.db_base_path}/{index_name}/db.string", "rb") as file:
            byte_stream = file.read()
        self.string_db = pickle.loads(byte_stream)
        self.vector_db = faiss.read_index(f"{self.db_base_path}/{index_name}/db.index")
        self.files = open(f"{self.db_base_path}/{index_name}/name.txt", 'r', encoding='utf-8').read()

    def save_vector_db_to_local(self):
        now = datetime.now()
        folder_name = now.strftime("%Y-%m-%d_%H-%M-%S-%f")
        os.mkdir(f"{self.db_base_path}/{folder_name}")
        faiss.write_index(self.vector_db, f"{self.db_base_path}/{folder_name}/db.index")
        byte_stream = pickle.dumps(self.string_db)
        with open(f"{self.db_base_path}/{folder_name}/db.string", "wb") as file:
            file.write(byte_stream)
        with open(f"{self.db_base_path}/{folder_name}/name.txt", "w", encoding="utf-8") as file:
            file.write(self.files)

    def del_vector_db(self, file_name):
        shutil.rmtree(f"{self.db_base_path}/" + file_name)
        self.vector_db = None

    def get_vector_db(self):
        file_list = glob(f"{self.db_base_path}/*")
        return [x.split("/")[-1] for x in file_list]

    def time2file_name(self, path):
        return open(f"{self.db_base_path}/{path}/name.txt", 'r', encoding='utf-8').read()

    def load_first_vector_db(self):
        file_list = glob(f"{self.db_base_path}/*")
        index_name = file_list[0].split("/")[-1]
        self.load_vector_db_from_local(index_name)

    def rename(self, file_list, new_name):
        with open(f"{self.db_base_path}/{file_list}/name.txt", "w", encoding="utf-8") as file:
            file.write(new_name)

    def stream_predict(self, query, history):
        history.append((query, ''))
        res = ''
        response = "根据文件内容,这是一份详尽的法律规定。"
        for i in response:
            res += i
            time.sleep(0.01)
            history[-1] = (query, res)
            yield res, history

    def filter_space(self, string):
        result = ""
        count = 0
        for char in string:
            if char == " " or char == '\t':
                count += 1
                if count < 4:
                    result += char
            else:
                result += char
                count = 0
        return result

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = DocChatbot()
        return cls._instance

  • 17
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值