ChatGLM Series, Part 6: Question Answering over a Knowledge Base


1. Installing Milvus

Download milvus-standalone-docker-compose.yml and save it as docker-compose.yml:

wget https://github.com/milvus-io/milvus/releases/download/v2.3.2/milvus-standalone-docker-compose.yml -O docker-compose.yml

Start Milvus:

sudo docker-compose up -d
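
Once the containers are up, a quick way to check that Milvus is reachable is a short Python probe (a minimal sketch; it assumes pymilvus is installed and Milvus is listening on the default port 19530):

from pymilvus import connections, utility

# Connect to the standalone Milvus started by docker-compose
connections.connect(alias="default", host="localhost", port="19530")
print(utility.get_server_version())  # e.g. v2.3.2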

2. Document preprocessing

import os
import re
import jieba
import torch
import pandas as pd
from pymilvus import utility
from pymilvus import connections, CollectionSchema, FieldSchema, Collection, DataType
from transformers import AutoTokenizer, AutoModel

connections.connect(
    alias="default",
    host='localhost',
    port='19530'
)

# Collection name, embedding dimension, and source document folder
collection_name = "document"
dimension = 768
docs_folder = "./knowledge/"

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese")


# Embed text as the final-layer [CLS] vector from bert-base-chinese
def get_vector(text):
    # truncation=True clips input to the model's 512-token limit
    input_ids = tokenizer(text, padding=True, truncation=True, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        # Take the hidden state at the [CLS] position (index 0)
        output = model(input_ids)[0][:, 0, :].numpy()
    return output.tolist()[0]


def create_collection():
    # Define the collection fields
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True, description="primary id"),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=10000),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
    ]

    # Define the collection schema
    schema = CollectionSchema(fields=fields, description="collection schema")

    # If the collection already exists, drop and recreate it.
    # (Uncomment the return instead if you only want to append new documents.)
    if utility.has_collection(collection_name):
        # return
        utility.drop_collection(collection_name)

    # Create the collection and build an index on the vector field
    collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)
    default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 2048}, "metric_type": "IP"}
    collection.create_index(field_name="vector", index_params=default_index)
    print(f"Collection {collection_name} created successfully")


def init_knowledge():
    collection = Collection(collection_name)
    # Walk the documents folder and import every text file into the Milvus collection
    docs = []
    for root, dirs, files in os.walk(docs_folder):
        for file in files:
            # Only process .txt files
            if file.endswith(".txt"):
                file_path = os.path.join(root, file)
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                # Collapse runs of whitespace
                content = re.sub(r"\s+", " ", content)
                title = os.path.splitext(file)[0]
                # Tokenize with jieba, then rejoin into a space-separated string
                words = jieba.lcut(content)
                content = " ".join(words)
                # Embed the title together with the content
                vector = get_vector(title + content)
                docs.append({"title": title, "content": content, "vector": vector})

    # Insert titles, contents, and vectors into the collection via a DataFrame
    df = pd.DataFrame(docs)
    collection.insert(df)
    collection.flush()  # make the inserted rows visible to searches
    print("Documents inserted successfully")


if __name__ == "__main__":
    create_collection()
    init_knowledge()
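
After the script finishes, one quick sanity check (a minimal sketch, assuming the same local Milvus instance) is to count the entities in the collection:

from pymilvus import connections, Collection

connections.connect(alias="default", host="localhost", port="19530")
collection = Collection("document")
print(collection.num_entities)  # should match the number of imported .txt files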

3. Knowledge-base matching

Use the vector index to find the documents most similar to the question.

import torch
from document_preprocess import get_vector
from pymilvus import Collection

collection = Collection("document")  # Get an existing collection.
collection.load()
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"


# Query: return the three documents most similar to the input text
def search_similar_text(input_text):
    # Embed the input text
    input_vector = get_vector(input_text)
    # Retrieve the IDs of the three closest vectors
    similarity = collection.search(
        data=[input_vector],
        anns_field="vector",
        param={"metric_type": "IP", "params": {"nprobe": 10}, "offset": 0},
        limit=3,
        expr=None,
        consistency_level="Strong"
    )
    ids = similarity[0].ids
    # Fetch the corresponding knowledge-base documents by ID
    res = collection.query(
        expr=f"id in {ids}",
        offset=0,
        limit=3,
        output_fields=["id", "content", "title"],
        consistency_level="Strong"
    )
    print(res)
    return res


if __name__ == "__main__":
    question = input('Please enter your question: ')
    search_similar_text(question)
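
One caveat worth knowing: the index uses the IP (inner product) metric, but get_vector returns unnormalized BERT embeddings, so raw vector magnitudes influence the ranking. If you prefer cosine-style similarity, one option (a sketch, not part of the original code; get_normalized_vector is a hypothetical helper) is to L2-normalize every vector before inserting and searching:

import numpy as np

from document_preprocess import get_vector

def get_normalized_vector(text):
    # With unit-length vectors, inner product equals cosine similarity
    vec = np.asarray(get_vector(text))
    return (vec / np.linalg.norm(vec)).tolist()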

4. Generating the answer

from transformers import AutoModel, AutoTokenizer
from knowledge_query import search_similar_text


tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model = model.eval()


def predict(input, max_length=2048, top_p=0.7, temperature=0.95, history=None):
    history = history or []
    # Retrieve the most similar knowledge-base documents for this question
    res = search_similar_text(input)
    # Chinese prompt: "Answer the user's question concisely and professionally based on
    # the known content below; if no answer can be derived, say the session only supports
    # one type of question and ask the user to clear the history and retry; do not
    # fabricate; answer in Chinese."
    prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。

已知内容:
{res}

问题:
{input}
"""
    query = prompt_template
    for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        yield response, history
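
The sketch above streams the answer directly. The complete Gradio web demo below builds on the same idea: only the first question after each history reset is augmented with retrieved context (tracked by the is_knowledge flag), and later turns reuse the conversation history.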

from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html

from knowledge_query import search_similar_text

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model = model.eval()
is_knowledge = True

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def predict(input, chatbot, max_length, top_p, temperature, history):
    global is_knowledge

    chatbot.append((parse_text(input), ""))
    query = input
    if is_knowledge:
        res = search_similar_text(input)
        # Same Chinese prompt as above: answer only from the retrieved context
        prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。

已知内容:
{res}

问题:
{input}
"""
        query = prompt_template
        is_knowledge = False
    for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (parse_text(input), parse_text(response))

        yield chatbot, history


def reset_user_input():
    return gr.update(value='')


def reset_state():
    global is_knowledge

    # Re-enable knowledge-base retrieval for the next question once the history is cleared
    is_knowledge = True
    return [], []


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)
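
Running this script launches the demo in the browser (Gradio listens on http://127.0.0.1:7860 by default). Ask a question, and the answer is generated from the documents retrieved out of the Milvus knowledge base.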