话不多说,有了前面的基础我们直接给带有注释的代码吧
构建检索问答链
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
from langchain.vectorstores.chroma import Chroma
from dotenv import find_dotenv, load_dotenv
import os
from langchain_community.llms import QianfanLLMEndpoint
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
# 首先加载一下embedding和向量数据库
embedding = QianfanEmbeddingsEndpoint()
persist_directory = './vector_db_test'
vectordb = Chroma(
persist_directory=persist_directory, # 允许我们将persist_directory目录保存到磁盘上
embedding_function=embedding
)
print(f"向量库中存储的数量:{vectordb._collection.count()}")
# 加载一下模型
load_dotenv(find_dotenv())
QIANFAN_AK = os.environ["QIANFAN_AK"]
QIANFAN_SK = os.environ["QIANFAN_SK"]
llm = QianfanLLMEndpoint(streaming=True)
# 制作一下prompt模板
template = """使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答
案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说“谢谢你的提问!”。
{context}
问题: {question}
"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context","question"],template=template)
# 形成链
qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectordb.as_retriever(),return_source_documents=True,chain_type_kwargs={"prompt":QA_CHAIN_PROMPT})
#问题
question_1 = "什么是南瓜书?"
question_2 = "Prompt Engineering for Developer是谁写的?"
# 基于召回结果和 query 结合起来构建的 prompt 效果
result = qa_chain({"query": question_1})
print("大模型+知识库后回答 question_1 的结果:", result["result"])
# 大模型自己回答的效果
prompt_template = """请回答下列问题:{}""".format(question_1)
out = llm.predict(prompt_template)
print('大模型自己回答的效果_q1:',out)
利用streamlit构建小demo
import streamlit as st
from langchain_community.llms import QianfanLLMEndpoint
import os
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import sys
sys.path.append("./docs/C3") # 将父目录放入系统路径中
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
from langchain.vectorstores.chroma import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file
QIANFAN_AK = os.environ["QIANFAN_AK"]
QIANFAN_SK = os.environ["QIANFAN_SK"]
def generate_response(input_text):
llm = QianfanLLMEndpoint(streaming=True)
output = llm.invoke(input_text)
output_parser = StrOutputParser()
output = output_parser.invoke(output)
#st.info(output)
return output
def get_vectordb():
# 定义 Embeddings
embedding = QianfanEmbeddingsEndpoint()
# 向量数据库持久化路径
persist_directory = './vector_db_test'
# 加载数据库
vectordb = Chroma(
persist_directory=persist_directory, # 允许我们将persist_directory目录保存到磁盘上
embedding_function=embedding
)
return vectordb
#带有历史记录的问答链
def get_chat_qa_chain(question:str):
vectordb = get_vectordb()
llm = QianfanLLMEndpoint(streaming=True)
memory = ConversationBufferMemory(
memory_key="chat_history", # 与 prompt 的输入变量保持一致。
return_messages=True # 将以消息列表的形式返回聊天记录,而不是单个字符串
)
retriever=vectordb.as_retriever()
qa = ConversationalRetrievalChain.from_llm(
llm,
retriever=retriever,
memory=memory
)
result = qa({"question": question})
return result['answer']
#不带历史记录的问答链
def get_qa_chain(question:str):
vectordb = get_vectordb()
llm = QianfanLLMEndpoint(streaming=True)
template = """使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答
案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说“谢谢你的提问!”。
{context}
问题: {question}
"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context","question"],
template=template)
qa_chain = RetrievalQA.from_chain_type(llm,
retriever=vectordb.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt":QA_CHAIN_PROMPT})
result = qa_chain({"query": question})
return result["result"]
# Streamlit 应用程序界面
def main():
st.title('demoLLM助手')
# 添加一个选择按钮来选择不同的模型
#selected_method = st.sidebar.selectbox("选择模式", ["qa_chain", "chat_qa_chain", "None"])
selected_method = st.radio(
"你想选择哪种模式进行对话?",
["None", "qa_chain", "chat_qa_chain"],
captions = ["不使用检索问答的普通模式", "不带历史记录的检索问答模式", "带历史记录的检索问答模式"])
# 用于跟踪对话历史
if 'messages' not in st.session_state:
st.session_state.messages = []
messages = st.container(height=300)
if prompt := st.chat_input("Say something"):
# 将用户输入添加到对话历史中
st.session_state.messages.append({"role": "user", "text": prompt})
if selected_method == "None":
# 调用 respond 函数获取回答
answer = generate_response(prompt)
elif selected_method == "qa_chain":
answer = get_qa_chain(prompt)
elif selected_method == "chat_qa_chain":
answer = get_chat_qa_chain(prompt)
# 检查回答是否为 None
if answer is not None:
# 将LLM的回答添加到对话历史中
st.session_state.messages.append({"role": "assistant", "text": answer})
# 显示整个对话历史
for message in st.session_state.messages:
if message["role"] == "user":
messages.chat_message("user").write(message["text"])
elif message["role"] == "assistant":
messages.chat_message("assistant").write(message["text"])
if __name__ == "__main__":
main()