chatglm6b和闻达的功能扩展

最近大火的chatgpt,老板说让我看看能不能用自己的数据,回答专业一些,所以做了一些调研,最近用这个倒是成功推理了自己的数据,模型也开源了,之后有机会也训练一下自己的数据。

1.本机部署

1.1环境部署

1.1双击打开anconda prompt创建虚拟环境

Conda create –n chatglm python#(创建名叫chatglm的虚拟python环境)
Conda activate chatglm#(激活环境)

1.2下载pytorch(这里要根据自己的电脑版本下载)都在虚拟环境里操作

nvidia-smi#(查看自己cuda版本)
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118#(下载符合自己配置的torch,可以在官网https://pytorch.org/查看命令)

在这里插入图片描述
1.3在官网https://download.pytorch.org/whl/torch_stable.html下载对应的cuda版本的torch和torchvision,然后pip install即可
这时gpu版的torch就下载成功:,验证方法如图:
在这里插入图片描述
1.4安装依赖库

cd C:\Users\dz\Desktop\AIGC\wenda\wd-git\wenda\requirements#(进入工具包的simple目录下)
pip install –r .\requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install protobuf flatbuffers termcolor#(根据提示下载需要的包和自己的模型requirements.txt文件)

1.2 配置参数

  1. 配模型:下载对应的模型权重文件,放到model文件夹下面,这里我用的是RWKV:
    在这里插入图片描述
  2. 配数据:自己的文本数据放到txt文件夹下面:
    在这里插入图片描述

3.配环境:在environment里面把环境配成自己刚刚创建的虚拟环境
在这里插入图片描述

在config里面把权重文件的地址和配置改成自己的

在这里插入图片描述

1.3. 推理

  1. 双击step.2本地数据库建库.bat建本地数据库
    在这里插入图片描述
  2. 双击run_rwkv-点击运行.bat运行这个模型,然后浏览器打开http://127.0.0.1:17860/
    首先测试是否检测到本地数据库

问答功能

2.云服务器部署

电脑跑起来不行,所以在云服务器上搞了一个,本来是git源码的,但是源码git下来运行有问题,所以我还是把本地文件放到自己仓库,重新git了一下,云服务器租环境,就租wenda环境,然后

git clone https://github.com/Turing-dz/wenda_zoe_test.git

修改example.config.xml文件里的模型地址,然后就可以推理自己的数据了。

python pluges/gen_data_st.py#运行本地数据库
python wenda.py -t glm6b -p 6006#云上规定用6006映射

然后打开链接,打开知识库按钮,就会推理自己的数据文件了。

3.项目需求

3.1 修改前端的名字

修改views/static/string.js里面的常量值就可以。

3.2 不同用户用不同的知识库

这个其实是一个安全问题,但代码修改起来也很简单,分两步,一个是生成不同的知识库,下一步就是调用不同的知识库。

3.2.1修改生成不同目录的知识库文件

1.修改example.config.yml,当用户没有给-u参数时,默认txt下的文件生成到memory的default1文件夹下。

user_Type: default1

在这里插入图片描述
2.修改common.py文件,设置用户输入-u参数,如果没输入就用上面设置的默认default1

parser.add_argument('-u', type=str, dest="user_to_knowledge", help="不同用户的本地知识库")
user_Type = str(args.user_to_knowledge) 
if  user_Type != 'None':
    settings.user_Type=user_Type

在这里插入图片描述
在这里插入图片描述
3.修改gen_data_st.py文件,这个文件是生成知识库的,所以要修改生成地址

add_knowledge='memory/'+settings.user_Type
try:
    vectorstore_old = FAISS.load_local(
        add_knowledge, embeddings=embeddings)
    success_print("合并至已有索引。如不需合并请删除 add_knowledge 文件夹")
    vectorstore_old.merge_from(vectorstore)
    vectorstore_old.save_local(add_knowledge)

请添加图片描述
请添加图片描述

3.2.2 不同用户用不同知识库

修改zhishiku_rtst.py文件

def find(s,step = 0,memory_name=settings.user_Type): 

请添加图片描述

3.2.3效果

python '/root/autodl-fs/wenda_zoe_test/plugins/gen_data_st.py' -u u2
python '/root/autodl-fs/wenda_zoe_test/wenda.py' -u u2 -t glm6b -p 6006
python '/root/autodl-fs/wenda_zoe_test/plugins/gen_data_st.py' -u u5
python '/root/autodl-fs/wenda_zoe_test/wenda.py' -u u5 -t glm6b -p 6006

3.2.4一个txt或pdf自动生成一个独立的知识库

天哥需要一个文件生成一个知识库。这个就更简单了,修改gen_data_st.py文件,

#add_knowledge='memory/'+settings.user_Type#这个是上次的-u功能,可以先注释
#下面两段代码加到for循环里,并把地下的代码都右移一位,加到for循环里面
add_knowledge='memory/'+file
add_knowledge=add_knowledge.split(".")[0]

在这里插入图片描述
但在后面需要返回score最大文章的content时,发现了bug,上面改完之后每次生成下一个文件的知识库时都会把之前的包括了,所以如果数据要独立,还得在all_files的循环开始加上

docs=[]
vectorstore = None

最好把下面的合并索引也删掉。所以改完的gen_data_st .py如下:

import argparse
import sentence_transformers
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
import threading
import pdfplumber
import re
import chardet
import os
import sys
import time
os.chdir(sys.path[0][:-8])
from common import success_print
from common import error_helper
from common import settings
from common import CounterLock
source_folder = 'txt'
source_folder_path = os.path.join(os.getcwd(), source_folder)
#add_knowledge='memory/'+settings.user_Type
import logging
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.ERROR)
root_path_list = source_folder_path.split(os.sep)
docs = []
vectorstore = None
model_path = settings.librarys.rtst.model_path
try:
    embeddings = HuggingFaceEmbeddings(model_name='')
    embeddings.client = sentence_transformers.SentenceTransformer(
        model_path, device="cuda")
except Exception as e:
    error_helper("embedding加载失败,请下载相应模型",
                 r"https://github.com/l15y/wenda#st%E6%A8%A1%E5%BC%8F")
    raise e
success_print("Embedding 加载完成")
embedding_lock=CounterLock()
vectorstore_lock=threading.Lock()
def clac_embedding(texts, embeddings, metadatas):
    global vectorstore
    with embedding_lock:
        vectorstore_new = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
    with vectorstore_lock:
        if vectorstore is None:
            vectorstore = vectorstore_new
        else:
            vectorstore.merge_from(vectorstore_new)
def make_index():
    global docs
    if hasattr(settings.librarys.rtst,"size") and hasattr(settings.librarys.rtst,"overlap"):
        text_splitter = CharacterTextSplitter(
            chunk_size=int(settings.librarys.rtst.size), chunk_overlap=int(settings.librarys.rtst.overlap), separator='\n')
    else:
        text_splitter = CharacterTextSplitter(
            chunk_size=20, chunk_overlap=0, separator='\n')
    doc_texts = text_splitter.split_documents(docs)
    docs = []
    texts = [d.page_content for d in doc_texts]
    metadatas = [d.metadata for d in doc_texts]
    thread = threading.Thread(target=clac_embedding, args=(texts, embeddings, metadatas))
    thread.start()
    while embedding_lock.get_waiting_threads()>2:
        time.sleep(0.1)
all_files=[]
for root, dirs, files in os.walk(source_folder_path):
    for file in files:
        all_files.append([root, file])
success_print("文件列表生成完成",len(all_files))
for i in range(len(all_files)):
    root, file=all_files[i]
    length_of_read=0
    docs=[]
    vectorstore = None
    data = ""
    title = ""
    try:
        if file.endswith(".pdf"):
            file_path = os.path.join(root, file)
            with pdfplumber.open(file_path) as pdf:
                data_list = []
                for page in pdf.pages:
                    print(page.extract_text())
                    data_list.append(page.extract_text())
                data = "\n".join(data_list)
        else:
            # txt
            file_path = os.path.join(root, file)
            with open(file_path, 'rb') as f:
                b = f.read()
                result = chardet.detect(b)
            with open(file_path, 'r', encoding=result['encoding']) as f:
                data = f.read()
        add_knowledge='memory/'+file
        add_knowledge=add_knowledge.split(".")[0]
    except Exception as e:
        print("文件读取失败,当前文件已被跳过:",file,"。错误信息:",e)
    data = re.sub(r'!', "!\n", data)
    data = re.sub(r':', ":\n", data)
    data = re.sub(r'。', "。\n", data)
    data = re.sub(r'\r', "\n", data)
    data = re.sub(r'\n\n', "\n", data)
    data = re.sub(r"\n\s*\n", "\n", data)
    length_of_read+=len(data)
    docs.append(Document(page_content=data, metadata={"source": file}))
    if length_of_read > 1e5:
            success_print("处理进度",int(100*i/len(all_files)),f"%\t({i}/{len(all_files)})")
            make_index()
            # print(embedding_lock.get_waiting_threads())
            length_of_read=0
    if len(all_files) == 0:
        #error_print("txt 目录没有数据")
        print("txt 目录没有数据")
        sys.exit(0)
    if len(docs) > 0:
        make_index()
    while embedding_lock.get_waiting_threads()>0:
        time.sleep(0.1)
    with embedding_lock:
        time.sleep(0.1)
        with vectorstore_lock:
            success_print("处理完成")
    # try:
    #     vectorstore_old = FAISS.load_local(
    #         add_knowledge, embeddings=embeddings)
    #     success_print("合并至已有索引。如不需合并请删除 add_knowledge 文件夹")
    #     vectorstore_old.merge_from(vectorstore)
    #     vectorstore_old.save_local(add_knowledge)
    # except:
    # print("新建索引")
    vectorstore.save_local(add_knowledge)
    success_print("保存完成")

3.2.5返回score值最低的知识库prompt

需要遍历生成的知识库,所以在zhishiku_rtst.py里面加上

source_folder = 'memory'
memory_name_list=[]
source_folder_path = os.path.join(os.getcwd(), source_folder)
for root, dirs, files in os.walk(source_folder_path):
    for dir in dirs:
        memory_name_list.append(dir)

然后在find函数里遍历,并计算score值,score越大距离越远,所以要最小的prompt,所以zhishiku_rtst.py文件如下:

from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import sentence_transformers
import numpy as np
import re,os
from plugins.common import settings,allowCROS
from plugins.common import error_helper 
from plugins.common import success_print 
divider='\n'
if not os.path.exists('memory'):
    os.mkdir('memory')
cunnrent_setting=settings.librarys.rtst
#print(cunnrent_setting.user_to_knowledge)
def get_doc_by_id(id,memory_name):
    return vectorstores[memory_name].docstore.search(vectorstores[memory_name].index_to_docstore_id[id])
def process_strings(A, C, B):
    # find the longest common suffix of A and prefix of B
    common = ""
    for i in range(1, min(len(A), len(B)) + 1):
        if A[-i:] == B[:i]:
            common = A[-i:]
    # if there is a common substring, replace one of them with C and concatenate
    if common:
        return A[:-len(common)] + C + B
    # otherwise, just return A + B
    else:
        return A + B
def get_doc(id,score,step,memory_name):
    doc = get_doc_by_id(id,memory_name)
    final_content=doc.page_content
    print("文段分数:",score,[doc.page_content])
        # print(id,score,step,memory_name,doc)
    if step > 0:
        for i in range(1, step+1):
            try:
                doc_before=get_doc_by_id(id-i,memory_name)
                if doc_before.metadata['source']==doc.metadata['source']:
                    final_content=process_strings(doc_before.page_content,divider,final_content)
                    # print("上文分数:",score,doc.page_content)
            except:
                pass
            try:
                doc_after=get_doc_by_id(id+i,memory_name)
                if doc_after.metadata['source']==doc.metadata['source']:
                    final_content=process_strings(final_content,divider,doc_after.page_content)
            except:
                pass
    if doc.metadata['source'].endswith(".pdf") or doc.metadata['source'].endswith(".txt"):
        title=f"[{doc.metadata['source']}](/api/read_news/{doc.metadata['source']})"
    else:
        title=doc.metadata['source']
    return {'title': title,'content':re.sub(r'\n+', "\n", final_content),"score":int(score)}
source_folder = 'memory'
memory_name_list=[]
source_folder_path = os.path.join(os.getcwd(), source_folder)
for root, dirs, files in os.walk(source_folder_path):
    for dir in dirs:
        memory_name_list.append(dir)
success_print(memory_name_list)
def find(s,step = 0,memory_name="test2"):  #"test2",
    try:
        scor_min=7000
        docs_min=[]
        for memory_name in memory_name_list:
            docs = []
            scor=0
            n=0
            embedding = get_vectorstore(memory_name).embedding_function(s)
            scores, indices = vectorstores[memory_name].index.search(np.array([embedding], dtype=np.float32), int(cunnrent_setting.count))
            #print("scores, indices:",scores, indices)
            for j, i in enumerate(indices[0]):
                if i == -1:continue
                if scores[0][j]>7000:continue
                docs.append(get_doc(i,scores[0][j],step,memory_name))
                scor+=scores[0][j]
                n+=1
            if n!=0:
                scor=scor/n
            if scor_min>scor:
                scor_min=scor
                docs_min=docs
        docs=docs_min
        #print(scor_min)
        print(docs)
        return docs
    except Exception as e:
        print(e)
        return []
try:
    embeddings = HuggingFaceEmbeddings(model_name='')
    embeddings.client = sentence_transformers.SentenceTransformer(cunnrent_setting.model_path,                                                                         device=cunnrent_setting.device)
except Exception  as e:
    error_helper("embedding加载失败,请下载相应模型",r"https://github.com/l15y/wenda#st%E6%A8%A1%E5%BC%8F")
    raise e
vectorstores={}
def get_vectorstore(memory_name):
    try:
        return vectorstores[memory_name]
    except Exception  as e:
        try:
            vectorstores[memory_name] = FAISS.load_local(
                'memory/'+memory_name, embeddings=embeddings)
            return vectorstores[memory_name]
        except Exception  as e:
            success_print("没有读取到RTST记忆区%s,将新建。"%memory_name)
    return None
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from bottle import route, response, request, static_file, hook
import bottle
@route('/api/upload_rtst_zhishiku', method=("POST","OPTIONS"))
def upload_zhishiku():
    allowCROS()
    try:
        data = request.json
        title=data.get("title")
        memory_name=data.get("memory_name")
        data = re.sub(r'!', "!\n", data.get("txt"))
        data = re.sub(r'。', "。\n", data)
        data = re.sub(r'[\n\r]+', "\n", data)
        docs=[Document(page_content=data, metadata={"source":title })]
        print(docs)
        text_splitter = CharacterTextSplitter(
            chunk_size=20, chunk_overlap=0, separator='\n')
        doc_texts = text_splitter.split_documents(docs)
        texts = [d.page_content for d in doc_texts]
        metadatas = [d.metadata for d in doc_texts]
        vectorstore_new = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
        vectorstore=get_vectorstore(memory_name)
        if vectorstore is None:
            vectorstores[memory_name]=vectorstore_new
        else:
            vectorstores[memory_name].merge_from(vectorstore_new)
        return '成功'
    except Exception as e:
        return str(e)
@route('/api/save_rtst_zhishiku', method=("POST","OPTIONS"))
def save_zhishiku():
    allowCROS()
    try:
        data = request.json
        memory_name=data.get("memory_name")
        vectorstores[memory_name].save_local('memory/'+memory_name)
        #print("保存到了"+'memory/'+memory_name)
        return "保存成功"
    except Exception as e:
        return str(e)
import json
@route('/api/find_rtst_in_memory', method=("POST","OPTIONS"))
def api_find():
    allowCROS()
    data = request.json
    prompt = data.get('prompt')
    step = data.get('step')
    memory_name=data.get("memory_name")
    if step is None:
        step = int(settings.library.step)
    # for i in range
    return json.dumps(find(prompt,int(step),memory_name_list))
@route('/api/save_news', method=("POST","OPTIONS"))
def save_news():
    allowCROS()
    try:
        data = request.json
        if not data:
            return 'no data'
        title = data.get('title')
        txt = data.get('txt')
        cut_file = f"txt/{title}.txt"
        with open(cut_file, 'w', encoding='utf-8') as f:
            f.write(txt)
            f.close()
        return 'success'
    except Exception as e:
        return(e)
@route('/api/read_news/:path', method=("GET","OPTIONS"))
def read_news(path=""):
    allowCROS()
    return static_file(path, root="txt/")

3.3 ptuning微调

3.3.1chatglm的ptuning

这里首先用官方的工具,生成对话的json数据,然后把autodl-tmp/ChatGLM-6B/ptuning/AdvertiseGen/里面的训练和测试的json数据替换成工具生成的自己的数据;修改autodl-tmp/ChatGLM-6B/ptuning/train.sh里面文件的地址,和数据的column,然后bash train.sh
在这里插入图片描述
训练完后可以运行web_demo.py文件测试效果。

3.3.2闻达的ptuning

我这里是将上面train完的autodl-tmp/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000文件复制到wenda的model/ptuning目录下。
在config.yml的里面glm6b下加入了

ptuning: "autodl-fs/wenda_zoe_test/model/ptuning"

plugins/common.py文件加入参数:

ptuning_addr='model/ptuning'
pre_seq_len=128
prefix_projection=False
if  ptuning_addr != 'None':
    settings.ptuning_addr=ptuning_addr
if  pre_seq_len != 'None':
    settings.pre_seq_len=pre_seq_len
if  prefix_projection is not True:
    settings.prefix_projection=prefix_projection

在plugins/llm_glm6b.py里面改掉模型的加载:

	#model = AutoModel.from_pretrained(settings.llm.path, local_files_only=True, trust_remote_code=True)
	config = AutoConfig.from_pretrained(settings.llm.path, trust_remote_code=True)
	config.pre_seq_len = settings.pre_seq_len
	config.prefix_projection = settings.prefix_projection
	tokenizer = AutoTokenizer.from_pretrained(settings.llm.path, local_files_only=True, trust_remote_code=True)
	if settings.ptuning_addr is not None:
        import torch
        model = AutoModel.from_pretrained(settings.llm.path, config=config,trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(settings.ptuning_addr, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(settings.llm.path, config=config,trust_remote_code=True)

然后再运行wenda.py测试自己做的数据集,就会看到ptuning效果。
在这里插入图片描述

3.4做socket接口

1.wenda_server.py

import logging
logging.captureWarnings(True)
import torch
import threading
import os
import json
import datetime
from bottle import route, response, request, static_file, hook
import bottle

from plugins.common import settings 
from plugins.common import error_helper,error_print,success_print
from plugins.common import CounterLock,allowCROS

#memory_name='test2'
def load_LLM():
    try:
        from importlib import import_module
        LLM = import_module('plugins.llm_'+settings.llm_type)
        return LLM
    except Exception as e:
        print("LLM模型加载失败,请阅读说明:https://github.com/l15y/wenda", e)
LLM = load_LLM()

logging=settings.loggings
if logging:
    from plugins.defineSQL import session_maker, 记录

if not hasattr(LLM,"Lock") :
    mutex = CounterLock()
else:
    mutex = LLM.Lock()


model = None
tokenizer = None


def load_model():
    with mutex:
        LLM.load_model()
    torch.cuda.empty_cache()
    success_print("模型加载完成")


thread_load_model = threading.Thread(target=load_model)
thread_load_model.start()
zhishiku = None

def load_zsk():
    try:
        from importlib import import_module
        global zhishiku
        import plugins.zhishiku as zsk
        zhishiku= zsk
        success_print("知识库加载完成")
    except Exception as e:
        error_helper("知识库加载失败,请阅读说明",r"https://github.com/l15y/wenda#%E7%9F%A5%E8%AF%86%E5%BA%93")
        raise e
    
thread_load_zsk = threading.Thread(target=load_zsk)
thread_load_zsk.start()
import re
footer = ''
from socket import *
IP = '127.0.0.1'
PORT = 50000
BUFLEN = 512
listenSocket = socket(AF_INET, SOCK_STREAM)
listenSocket.bind((IP, PORT))
listenSocket.listen(8)
print(f'服务端启动成功,在{PORT}端口等待客户端连接...')
dataSocket, addr = listenSocket.accept()
print('接受一个客户端连接:', addr)
while True:
    # response.content_type = "text/event-stream"
    # response.add_header("Connection", "keep-alive")
    # response.add_header("Cache-Control", "no-cache")
    max_length = None
    if max_length is None:
        max_length = 2048
    top_p = None
    if top_p is None:
        top_p = 0.2
    temperature = None
    if temperature is None:
        temperature = 0.8
    use_zhishiku = None
    if use_zhishiku is None:
        use_zhishiku = False
    recved = dataSocket.recv(BUFLEN)
    if not recved:
        break
    prompt = recved.decode()
    keyword=None
    if keyword is None:
        keyword = prompt
    history_formatted = None
    response_text = ''
    IP = request.environ.get(
        'HTTP_X_REAL_IP') or request.environ.get('REMOTE_ADDR')
    error = ""
    if use_zhishiku:
        
        response_d = zhishiku.find(keyword,int(settings.library.step))
        output_sources = [i['title'] for i in response_d]
        results = '\n'.join([str(i+1)+". "+re.sub('\n\n', '\n', response_d[i]['content']) for i in range(len(response_d))])
        prompt = 'system: 请扮演一名专业分析师,根据以下内容回答问题:'+prompt + "\n"+ results
        if settings.library.show_soucre == True:
            footer = "\n### 来源:\n"+('\n').join(output_sources)
    with mutex:
        try:
            for response in LLM.chat_one(prompt, history_formatted, max_length, top_p, temperature, zhishiku=use_zhishiku):
                if (response):
                    response= response+footer
        except Exception as e:
            error = str(e)
            error_print("错误", error)
            response = ''
            # raise e
        torch.cuda.empty_cache()
    if response == '':
        response= "发生错误,正在重新加载模型"+error
        os._exit(0)
    if logging:
        with session_maker() as session:
            jl = 记录(时间=datetime.datetime.now(), IP=IP,=prompt,=response)
            session.add(jl)
            session.commit()
    print(response)
    dataSocket.send(f'服务端返回信息: {response}'.encode())
    # yield "/././"
dataSocket.close()
listenSocket.close()

# import webbrowser
# webbrowser.open_new('http://127.0.0.1:'+str(settings.Port))

# import functools
# def pathinfo_adjust_wrapper(func):
#     # A wrapper for _handle() method
#     @functools.wraps(func)
#     def _(s,environ):
#         environ["PATH_INFO"] = environ["PATH_INFO"].encode("utf8").decode("latin1")
#         return func(s,environ)
#     return _
# bottle.Bottle._handle = pathinfo_adjust_wrapper(bottle.Bottle._handle)#修复bottle在处理utf8 url时的bug

# bottle.run(server='paste', host="0.0.0.0", port=settings.port, quiet=True)

2.client.py

from socket import *

IP = '127.0.0.1'
SERVER_PORT = 50000
BUFLEN = 1024

# 实例化一个socket对象,指明协议
dataSocket = socket(AF_INET, SOCK_STREAM)

# 连接服务端socket
dataSocket.connect((IP, SERVER_PORT))

while True:
    # 从终端读入用户输入的字符串
    toSend = input('>>> ')
    if  toSend =='exit':
        break
    # 发送消息,也要编码为 bytes
    dataSocket.send(toSend.encode())

    # 等待接收服务端的消息
    recved = dataSocket.recv(BUFLEN)
    # 如果返回空bytes,表示对方关闭了连接
    if not recved:
        break
    # 打印读取的信息
    print(recved.decode())

dataSocket.close()

3.5用langchain与sql数据库交互

用本地glm6b想解决数据库问题,因此就结合langchain来做,因为langchain和glm6b的适配问题,因此对langchain做了一点点处理,如下:
首先运行文件如下,将glm6b引入llm

from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
import sys
from typing import List,  Optional


class ChatGLM(LLM):
    max_token: int = 2048
    temperature: float = 0.8
    top_p = 0.1
    tokenizer: object = None
    model: object = None
    history_len: int = 1024

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "GLM"

    def load_model(self, llm_device="gpu", model_name_or_path=None):
        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config,
                                               trust_remote_code=True).half().cuda()

    def _call(self, prompt: str, history: List[str] = [], stop: Optional[List[str]] = None):
        response, _ = self.model.chat(
            self.tokenizer, prompt,
            # history=history[-self.history_len:] if self.history_len > 0 else [],
            max_length=self.max_token, temperature=self.temperature,
            top_p=self.top_p)
        return response


modelpath = r"C:\xxx\Desktop\wenda-main\wenda-main\model\chatglm2-6b"
sys.path.append(modelpath)
print(modelpath)
llm = ChatGLM()
llm.load_model(model_name_or_path=modelpath)
from langchain import SQLDatabase, SQLDatabaseChain
db = SQLDatabase.from_uri("mysql+pymysql://root:xxx@localhost/xunlian")
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("司明苏的单个人员成绩是多少?")

因为适配问题,所以需要对langchain/chain/sql_database/base.py文件里面的两次输出进行处理

class SQLDatabaseChain(Chain):
	def _call:
		if sql_cmd:
   	 		sql_cmd = sql_cmd.split(";")
    		sql_cmd = sql_cmd[0]+";"
    	if chain_result['result']:
            chain_result['result']=chain_result['result'].split("\n")[-1]

处理完后的base文件如下:

"""Chain for interacting with SQL Database."""
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional

from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.sql_database import SQLDatabase
from langchain.tools.sql_database.prompt import QUERY_CHECKER

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


class SQLDatabaseChain(Chain):
    """Chain for interacting with SQL Database.

    Example:
        .. code-block:: python

            from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    database: SQLDatabase = Field(exclude=True)
    """SQL Database to connect to."""
    prompt: Optional[BasePromptTemplate] = None
    """[Deprecated] Prompt to use to translate natural language to SQL."""
    top_k: int = 5
    """Number of results to return from the query"""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether or not the query checker tool should be used to attempt 
    to fix the initial SQL from the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """The prompt template that should be used by the query checker"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                database = values["database"]
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        intermediate_steps: List = []
        try:

            intermediate_steps.append(llm_inputs)  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if not self.use_query_checker:
                if sql_cmd:
                    sql_cmd = sql_cmd.split(";")
                    sql_cmd = sql_cmd[0]+";"
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)

                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                if result:
                    my_sec_prompt="Question: "+input_text+sql_cmd+"\n"+"SQLResult:"+result+"\n"+"Answer:"
                intermediate_steps.append(str(result))  # output: sql execs
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs)  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                final_result=final_result.split("\n")[-1]
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            # chain_result=chain_result.split(";")
            if chain_result['result']:
                chain_result['result']=chain_result['result'].split("\n")[-1]
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

    @property
    def _chain_type(self) -> str:
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)


class SQLDatabaseSequentialChain(Chain):
    """Chain for querying SQL database that is a sequential chain.

    The chain is as follows:
    1. Based on the query, determine which tables to use.
    2. Based on those tables, call the normal SQL database chain.

    This is useful in cases where the number of tables in the database is large.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        database: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Load the necessary chains."""
        sql_chain = SQLDatabaseChain.from_llm(
            llm, database, prompt=query_prompt, **kwargs
        )
        decider_chain = LLMChain(
            llm=llm, prompt=decider_prompt, output_key="table_names"
        )
        return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _table_names = self.sql_chain.database.get_usable_table_names()
        table_names = ", ".join(_table_names)
        llm_inputs = {
            "query": inputs[self.input_key],
            "table_names": table_names,
        }
        _lowercased_table_names = [name.lower() for name in _table_names]
        table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)
        table_names_to_use = [
            name
            for name in table_names_from_chain
            if name.lower() in _lowercased_table_names
        ]
        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(table_names_to_use), color="yellow", verbose=self.verbose
        )
        new_inputs = {
            self.sql_chain.input_key: inputs[self.input_key],
            "table_names_to_use": table_names_to_use,
        }
        return self.sql_chain(
            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
        )

    @property
    def _chain_type(self) -> str:
        return "sql_database_sequential_chain"

效果如下:
在这里插入图片描述

3.6自定义template与sql数据库交互并用flask前端展示

#1.自定义一个生成sql语句的template,这里需要传入table——info和question
from langchain.prompts import PromptTemplate
from langchain import SQLDatabase, SQLDatabaseChain
db = SQLDatabase.from_uri("mysql+pymysql://root:xxx@localhost/xunlian")
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
table_info=db_chain.database.get_table_info(table_names=None)
prompt1 = PromptTemplate.from_template("你是一个SQL专家,现在给你提供一个数据库表单的提示信息:{table_info}\n请根据上述#数据库表单的提示信息,"
                                      "针对{qusetion},创建一个语法正确的MySQL查询语句,使用LIMIT子句查询最多3个结果,必须将查询语句中的字段使用反引号(`)包括起来,"
                                      "必须使用数据库表单的提示信息中可见的列名创建MySQL查询语句,查询语句不能出现不存在的列名。请按以下输出示例进行输出:")
#2.sql查询
#3.将问题,sql语句,查询结果给模型,规定它输出人话
prompt2=PromptTemplate.from_template("数据库查询语句是:{first_format}\n数据库查询结果是:{second_format}\n请根据上述查询过程,,回答的内容必须简单明了,必须在30个字以内:{question}")
import pymysql
conn = pymysql.connect(host='localhost', user='root', password='xxx', database='xunlian')
cursor = conn.cursor()  
#4.将上面的步骤封装到一个函数
conn = pymysql.connect(host='localhost', user='root', password='xxx', database='xunlian')
cursor = conn.cursor()      
table_info=db_chain.database.get_table_info(table_names=None)    
prompt1 = PromptTemplate.from_template("你是一个SQL专家,现在给你提供一个数据库表单的提示信息:{table_info}\n请根据上述#数据库表单的提示信息,"
                                      "针对{qusetion},创建一个语法正确的MySQL查询语句,使用LIMIT子句查询最多3个结果,必须将查询语句中的字段使用反引号(`)包括起来,"
                                      "必须使用数据库表单的提示信息中可见的列名创建MySQL查询语句,查询语句不能出现不存在的列名。请按以下输出示例进行输出:")
# print(prompt1.format(table_info=table_info,qusetion="司明苏的单个人员成绩是多少?"))
prompt2=PromptTemplate.from_template("数据库查询语句是:{first_format}\n数据库查询结果是:{second_format}\n请根据上述查询过程,,回答的内容必须简单明了,必须在30个字以内:{question}")
def my_out(question):
    first=llm.predict(prompt1.format(table_info=table_info,qusetion=question))
    first_format=first.split("```")[1][len('sql'):].lstrip()
    
    cursor.execute(first_format)
    second_format= cursor.fetchall()
    
    third_format=llm.predict(prompt2.format(first_format=first_format, second_format=second_format,question=question))
    return third_format
# print(my_out("单个人员成绩在70分以上的姓名有谁?"))
while True:
    question = input("请输入一个名词:\n")
    if question == "结束":
        break
    else:
        print(my_out(question))
        continue                      

用flask进行交互

#1.index.html
<!DOCTYPE html>
<html>
<head>
    <title>一问一答</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            text-align: center;
        }

        h1 {
            color: #0080FF;
        }

        .container {
            margin: 50px auto;
            max-width: 400px;
            padding: 20px;
            border: 1px solid #ccc;
            border-radius: 10px;
        }

        .question-input {
            width: 100%;
            padding: 10px;
            margin-bottom: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
        }

        .submit-btn {
            background-color: #0080FF;
            color: #fff;
            border: none;
            padding: 10px 20px;
            border-radius: 5px;
            cursor: pointer;
        }

        .submit-btn:hover {
            background-color: #005eff;
        }

        .answer {
            margin-top: 20px;
            font-weight: bold;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>一问一答</h1>
        <p>请输入您的问题:</p>
        <input type="text" id="question" class="question-input">
        <button onclick="submitQuestion()" class="submit-btn">提交</button>
        <p class="answer">回答:</p>
        <p id="answer" class="answer"></p>
    </div>

    <script>
        function submitQuestion() {
            // 获取用户输入的问题
            var question = document.getElementById('question').value;

            // 创建一个FormData对象,用于将数据添加到POST请求中
            var formData = new FormData();
            formData.append('question', question);

            // 发送POST请求
            fetch('/get_answer', {
                method: 'POST',
                body: formData
            })
            .then(response => response.text())
            .then(answer => {
                // 显示回答
                document.getElementById('answer').innerText = answer;
            })
            .catch(error => {
                console.error('Error:', error);
            });
        }
    </script>
</body>
</html>

#2.app.py
from flask import Flask, render_template, request
from main_testdb import my_out
app = Flask(__name__)


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/get_answer', methods=['POST'])
def get_answer():
    # 获取前端传递的问题
    question = request.form['question']

    # 在这里你可以处理问题并返回相应的答案
    # 假设你的问题回答逻辑与之前的JavaScript示例相同
    if question == '你叫什么名字?':
        answer = '我叫xxx,很高兴为您服务!'
    elif question == '你会说中文吗?':
        answer = '是的,我会说中文,还会说多种其他语言。'
    else:
        answer = my_out(question)
    # 返回答案到前端
    return answer


if __name__ == '__main__':
    app.run(debug=True)

在这里插入图片描述

3.7自定义自己的tool和chain

#1.模型
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
import sys
from typing import List,  Optional

class ChatGLM(LLM):
    max_token: int = 8192
    temperature: float = 0.1
    top_p = 0.1
    tokenizer: object = None
    model: object = None
    history_len: int = 0

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "GLM"

    def load_model(self, llm_device="gpu", model_name_or_path=None):
        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config,
                                               trust_remote_code=True).half().cuda()

    def _call(self, prompt: str, history: List[str] = [], stop: Optional[List[str]] = None):
        response, _ = self.model.chat(
            self.tokenizer, prompt,
            max_length=self.max_token, temperature=self.temperature,
            top_p=self.top_p)
        return response

modelpath = r"C:\Users\robot\Desktop\wenda-main\wenda-main\model\chatglm2-6b"
sys.path.append(modelpath)
# print(modelpath)
llm = ChatGLM()
llm.load_model(model_name_or_path=modelpath)
#2.tools
import cv2
def catch_video(name='my_video', video_index=0):
    cap = cv2.VideoCapture(video_index) # 创建摄像头识别类
    if not cap.isOpened():
        raise Exception('Check if the camera is on.')
    while cap.isOpened():
        catch, frame = cap.read()  # 读取每一帧图片
        cv2.imshow(name, frame) # 在window上显示图片
        key = cv2.waitKey(10)
        if key & 0xFF == ord('q'):
            # 按q退出
            break
        if cv2.getWindowProperty(name, cv2.WND_PROP_AUTOSIZE) < 1:
            # 点x退出
            break
    # 释放摄像头
    cap.release()
    cv2.destroyAllWindows()
from langchain import SQLDatabase, SQLDatabaseChain
db = SQLDatabase.from_uri("mysql+pymysql://root:w19891207@localhost/xunlian")
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
import pymysql
conn = pymysql.connect(host='localhost', user='root', password='w19891207', database='xunlian')
cursor = conn.cursor()
from langchain.prompts import PromptTemplate
table_info="{"+db_chain.database.get_table_info(table_names=None)+"\n}"
table_info = table_info.split('\n')
table_info = ['#' + line for line in table_info]
table_info= '\n'.join(table_info)
# print(table_info)
prompt_sql = PromptTemplate.from_template("""你是一个SQL专家,现在给你提供一个数据库表单的提示信息。\n#数据库表单的提示信息包括:
{table_info}
请根据上述#数据库表单的提示信息,针对"用户问题",创建一个语法正确的MySQL查询语句,使用LIMIT子句查询最多3个结果,必须将查询语句中的字段使用反引号(`)包括起来,必须使用"数据库表单的提示信息"中可见的列名创建MySQL查询语句,查询语句不能出现不存在的列名。请按以下#输出示例进行输出:

#输出示例
#用户问题:张五的岗位是什么?
#MySQL查询语句:SELECT 岗位 FROM 总表 WHERE 姓名 = `张五`;

现在我们开始:
用户问题:{question}
MySQL查询语句:""")
prompt2=PromptTemplate.from_template("""数据库查询语句是:{first_format}\n数据库查询结果是:{second_format}\n请根据上述查询过程进行回答,回答的内容必须简单明了,必须在30个字以内:{question}""")
def my_out(question):
    # print(table_info)
    first=llm.predict(prompt_sql.format(table_info=table_info, question=question))
    cursor.execute(first)
    second_format= cursor.fetchall()
    third_format=llm.predict(prompt2.format(first_format=first, second_format=second_format,question=question))
    return third_format



from langchain.prompts import PromptTemplate
prompt1=PromptTemplate.from_template("""Chatbot
单选题,以下哪个工具可以完成问题或任务
#{question}()#
A:《数据库》工具,该工具用于访问训练数据库,对人员姓名、年龄训练成绩进行查询。
B:《作图》工具,绘图。
C:《监控视频》工具,该工具用于调用指定地点的摄像头、监控、视频、态势进行查看。
D:《备选》工具,该工具用于对措施、计划、管理、评估等定性问题进行回答,或其他不适合任何工具的情况完成任务或问题。""")
def use_mytools(question):
    choice = llm.predict(prompt1.format(question=question))
    print(choice)
    if "C" in choice:
        catch_video(question)
    elif "A" in choice:
        print(my_out(question))
    else:
        print(llm.predict(question))
# use_mytools("张国立的年龄是多少?")
use_mytools("调用摄像头")
  • 5
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是小z呀

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值