pip install langchain langchain_community transformers InstructorEmbedding sentence_transformers==2.2.2 faiss-gpu PyPDF2 streamlit pyngrok gradio fitz frontend
import os
# 设置环境变量
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# 检查环境变量是否已更新
!huggingface-cli download --resume-download BAAI/bge-m3 --token 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
!huggingface-cli download --resume-download baichuan-inc/Baichuan2-7B-Chat --token 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
# coding: utf-8
# Author: 唐国梁Tommy
# Date: 2023-08-06
import streamlit as st
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS, Milvus, Pinecone, Chroma
from langchain.memory import ConversationBufferMemory
from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
import streamlit as st
from PyPDF2 import PdfReader
@st.cache_resource
def _load_embedding_model():
    """Create the BAAI/bge-m3 embedding model once per Streamlit server process."""
    # NOTE(review): bge-m3 is not an Instructor model; HuggingFaceEmbeddings /
    # HuggingFaceBgeEmbeddings may be a better fit than HuggingFaceInstructEmbeddings
    # (which prepends instructions) - confirm retrieval quality.
    return HuggingFaceInstructEmbeddings(model_name="BAAI/bge-m3")


def _build_conversation(files):
    """Index the uploaded PDFs and store a new retrieval chain in st.session_state."""
    with st.spinner("请等待,处理中..."):
        # 4. Extract the text of every uploaded PDF.
        texts = extract_text_from_PDF(files)
        # 5. Split the text into chunks.
        content_chunks = split_content_into_chunks(texts)
        if not content_chunks:
            # Nothing to index (for example PDFs without a text layer).
            st.error("未能从上传的PDF文档中提取到文本")
            return
        # 6. Embed every chunk and store the vectors in the vector database.
        embedding_model = _load_embedding_model()
        vector_store = save_chunks_into_vectorstore(content_chunks, embedding_model)
        # 7. Build the conversational chain. It lives in session_state, so later
        # reruns of the script reuse it.
        # Docs: https://python.langchain.com/docs/modules/memory/types/buffer
        st.session_state.conversation = get_chat_chain(vector_store)


def main():
    """Render the Streamlit PDF question-answering page.

    The main area holds a question box. The sidebar lets the user upload PDFs and,
    on "提交并处理", builds a conversational retrieval chain over their text.
    The chain and the chat history are stored in ``st.session_state`` so they
    survive Streamlit reruns.
    """
    # Page configuration; must be the first Streamlit call of each script run.
    st.set_page_config(page_title="基于PDF文档的 QA ChatBot")
    st.header("基于LangChain+LLM实现QA ChatBot")

    # Reference: https://github.com/hwchase17/langchain-streamlit-template/blob/master/main.py
    # Initialise the per-session state on the first run.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    # 1. Question input box.
    user_input = st.text_input("基于上传的PDF文档,请输入你的提问: ")
    # Hand the question to the chain and display the response.
    if user_input:
        process_user_input(user_input)

    with st.sidebar:
        # 2. Sidebar subtitle.
        st.subheader("你的PDF文档")
        # 3. Upload documents. extract_text_from_PDF iterates over the result,
        # so multiple-file mode is required (it returns a list).
        files = st.file_uploader("上传PDF文档,然后点击'提交并处理'",
                                 accept_multiple_files=True)
        if st.button("提交并处理"):
            if not files:
                st.warning("请先上传至少一个PDF文档")
            else:
                _build_conversation(files)
def extract_text_from_PDF(files):
    """Return the concatenated text of every page of every PDF in *files*.

    Args:
        files: iterable of PDF file objects or paths accepted by PyPDF2's
            ``PdfReader`` (for example Streamlit ``UploadedFile`` objects).
            ``None`` or an empty iterable yields ``""``.

    Returns:
        The page texts joined with newlines, so the last line of one page is not
        glued onto the first line of the next.
    """
    # Reference: https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf
    if not files:
        # Nothing uploaded yet (the uploader returns None or an empty list).
        return ""
    page_texts = []
    for pdf in files:
        reader = PdfReader(pdf)
        # `or ""` guards against pages yielding no text object.
        page_texts.extend(page.extract_text() or "" for page in reader.pages)
    return "\n".join(page_texts)
def split_content_into_chunks(text, chunk_size=500, chunk_overlap=80):
    """Split *text* into overlapping chunks on newline boundaries.

    Args:
        text: the raw document text.
        chunk_size: maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: number of characters shared by consecutive chunks.

    Returns:
        A list of chunk strings (empty for empty input).
    """
    # Reference: https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/character_text_splitter
    text_splitter = CharacterTextSplitter(separator="\n",
                                          chunk_size=chunk_size,
                                          chunk_overlap=chunk_overlap,
                                          length_function=len)
    return text_splitter.split_text(text)
def save_chunks_into_vectorstore(content_chunks, embedding_model):
    """Embed *content_chunks* with *embedding_model* and return a FAISS vector store.

    Raises:
        ValueError: if *content_chunks* is empty; FAISS cannot build an index
            from zero vectors.
    """
    # Reference: https://python.langchain.com/docs/modules/data_connection/vectorstores/
    # pip install faiss-gpu (without a GPU: pip install faiss-cpu)
    if not content_chunks:
        raise ValueError("No text chunks to index; the uploaded PDFs may "
                         "contain no extractable text.")
    return FAISS.from_texts(texts=content_chunks, embedding=embedding_model)
@st.cache_resource
def _load_llm(model_path="baichuan-inc/Baichuan2-7B-Chat"):
    """Load the chat model once per Streamlit server process and wrap it for LangChain."""
    # Baichuan2 ships custom model and tokenizer code, so both need
    # trust_remote_code=True; its model card loads the tokenizer with use_fast=False.
    # torch_dtype="auto" uses the checkpoint's dtype instead of upcasting to fp32.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 trust_remote_code=True,
                                                 torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path,
                                              trust_remote_code=True,
                                              use_fast=False)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return HuggingFacePipeline(pipeline=pipe)


def get_chat_chain(vector_store):
    """Build a ConversationalRetrievalChain over *vector_store*.

    The LLM is loaded once and cached across clicks and reruns. Each call
    gets a fresh ConversationBufferMemory whose history is exposed under the
    ``chat_history`` key. The retriever returns the top 5 chunks.
    """
    # 1. The LLM (cached; reloading a 7B model on every submit exhausts memory).
    llm = _load_llm()
    # 2. Conversation history store.
    # Reference: https://python.langchain.com/docs/use_cases/question_answering/how_to/chat_vector_db
    memory = ConversationBufferMemory(memory_key='chat_history',
                                      return_messages=True)
    # 3. The conversational retrieval chain.
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
        memory=memory,
    )
    return conversation_chain
def process_user_input(user_input):
    """Send *user_input* to the session's conversation chain and render the chat.

    If no chain has been built yet (no PDFs processed), show a warning instead.
    The chain's ``chat_history`` is saved to ``st.session_state`` and rendered
    message by message.
    """
    print('输入内容 ' + user_input)
    if st.session_state.conversation is None:
        st.warning("请先在侧边栏上传PDF文档并点击'提交并处理'")
        return
    # Ask the chain the user's question.
    response = st.session_state.conversation.invoke({'question': user_input})
    # response is a dict; `'response ' + response` raised TypeError.
    print(f'response {response}')
    # session_state keeps data across Streamlit reruns.
    st.session_state.chat_history = response['chat_history']
    # Render the history. The role comes from each message's type ("human" for
    # user turns) rather than its position. Content goes through st.markdown
    # with HTML disabled, so user or model text cannot inject markup into the page.
    for message in st.session_state.chat_history:
        role = "user" if message.type == "human" else "assistant"
        with st.chat_message(role):
            st.markdown(message.content)
if __name__ == "__main__":
    from streamlit import runtime

    if runtime.exists():
        # Executed by `streamlit run`: render the app. Previously main() was never
        # called here, so the page stayed blank.
        main()
    else:
        # Executed directly (e.g. from the notebook): forward Streamlit's port 8501
        # to a public URL with ngrok, then start the Streamlit server.
        # NOTE(review): recent ngrok versions require an authtoken
        # (ngrok.set_auth_token) - confirm it is configured.
        import subprocess
        import sys

        from pyngrok import ngrok

        public_url = ngrok.connect(addr="8501", proto="http")
        tunnels = ngrok.get_tunnels()
        print("tunnels:", tunnels)
        # Print the public URL.
        print("Public URL:", public_url)
        # Plain-Python replacement for the IPython-only `!streamlit run ...`;
        # blocks while the server is running.
        subprocess.run(
            [sys.executable, "-m", "streamlit", "run", "/mnt/workspace/main.py"],
            check=True,
        )
