简介
在本文中,我将介绍如何使用“检索增强生成”(Retrieval-Augmented Generation,简称 RAG)技术来生成文章。RAG 是一种结合了检索模型与生成模型的技术:先从知识库中检索相关信息,再基于这些信息生成高质量的内容。
GitHub 仓库:
目录
- 环境配置
- 数据准备
- 检索
- 文章生成
- 结论
1. 环境配置
首先确保安装了以下库:
pip install langchain langchain-openai langchain-text-splitters langchain-community langchain-core FlagEmbedding
然后导入所需的模块:
from langchain.llms import Tongyi
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from langchain_core.prompts import PromptTemplate
from FlagEmbedding import FlagReranker
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores.utils import DistanceStrategy
from tqdm import tqdm
import numpy as np
2. 数据准备
数据准备包括两个步骤:文档分割和文本向量化。文档分割是将文档切分为长度适宜的片段,既保证不丢失太多信息,也适合输入大模型。文本向量化则是用 embedding 模型将分割后的文本转换为向量,方便检索。
from langchain_community.document_loaders import DirectoryLoader

# Load every document found under the data directory (sub-directories included).
data_loader = DirectoryLoader("your_path_to_load_data", recursive=True)
raw_knowledge_base = data_loader.load()

# Separators are tried in order, from the coarsest boundary to the finest.
# The empty string "" must stay LAST: the splitter stops searching as soon as
# it reaches "", so any separator listed after it would never be used.
# NOTE(review): the first five entries look like regex patterns, but
# is_separator_regex is not set below, so presumably they are matched as
# literal text -- confirm the intended behavior.
markdown_separators = [
    "\n#{1,6} ",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n___+\n",
    "\n\n",
    "\n",
    " ",
    ".",
    ",",
    "\u200B",  # zero-width space
    "\uff0c",  # fullwidth comma
    "\u3001",  # ideographic comma
    "\uff0e",  # fullwidth full stop
    "\u3002",  # ideographic full stop
    "",  # last resort: split between any two characters
]

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    add_start_index=True,
    strip_whitespace=True,
    separators=markdown_separators,
)

# Split document by document so tqdm can show progress.
data = []
for doc in tqdm(raw_knowledge_base):
    data.extend(text_splitter.split_documents([doc]))

# Drop chunks whose text is an exact duplicate, keeping the first occurrence.
data_index = set()
refined_data = []
for doc in data:
    if doc.page_content not in data_index:
        data_index.add(doc.page_content)
        refined_data.append(doc)

embedding_model = HuggingFaceEmbeddings(
    model_name='BAAI/bge-m3',
    cache_folder="your_path_to_save_model",
    multi_process=True,
    model_kwargs={"device": "cuda:0", "trust_remote_code": True},
    encode_kwargs={"normalize_embeddings": True},
)

# Build the FAISS index over the de-duplicated chunks, using cosine distance.
knowledge_vector_database = FAISS.from_documents(
    refined_data,
    embedding_model,
    distance_strategy=DistanceStrategy.COSINE,
)
3. 检索
为了检索信息,我们采用了一种综合的方法,即向量相似度和 BM25 相结合,这两种方法可以互相补充。第一种很简单:上面我们已经将数据向量化,检索时直接计算向量间的相似度。第二种使用的是 BM25,简单来说是 TF-IDF 的升级版。使用复合 retriever 的原因是二者能够相互弥补:向量检索有时不能直接返回包含关键词的资料,而 BM25 正好可以弥补这一点。接着,我们使用一个更复杂的模型(reranker)对检索出的文档进行重排序。
# Hybrid retrieval: the vector-store retriever and the BM25 keyword retriever
# are combined with equal weights.
retriever_vectordb = knowledge_vector_database.as_retriever()
keyword_retriever = BM25Retriever.from_documents(refined_data, k=10)
ensemble_retriever = EnsembleRetriever(
    retrievers=[retriever_vectordb, keyword_retriever],
    weights=[0.5, 0.5],
)

# Reranker model used to re-score the retrieved candidates against the query.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True, cache_dir="your_path_to_save_model")


def get_relevant_knowledge(query, k=100, final_extract_num=20):
    """Retrieve candidates for *query*, rerank them, and return the best texts.

    Args:
        query: Search text passed to the ensemble retriever and the reranker.
        k: Currently unused; kept so existing callers keep working.
        final_extract_num: Maximum number of passages to return.

    Returns:
        Page contents of the top-scoring documents, best first. Empty when
        nothing is retrieved or when ``final_extract_num`` is not positive.
    """
    # A slice of [-0:] would select *everything*, so handle this explicitly.
    if final_extract_num <= 0:
        return []

    relevant_doc = ensemble_retriever.invoke(query)
    if not relevant_doc:
        return []

    temp_docs = [[query, doc.page_content] for doc in relevant_doc]
    # atleast_1d guards against a reranker that returns a bare scalar
    # when it is given a single pair.
    scores = np.atleast_1d(reranker.compute_score(temp_docs))

    # argsort is ascending: take the tail, then reverse to get best-first.
    top_indices = np.argsort(scores)[-final_extract_num:][::-1]
    return [relevant_doc[i].page_content for i in top_indices]
4. 文章生成
为了生成文章,我们建议使用“思维链”提示方法(Chain-of-Thought Prompting)。首先基于主题生成一个大纲,然后针对大纲中的关键词进行二次检索。最后,根据大纲和检索到的信息生成文章。
完整的代码我已经放在了下面的 GitHub 链接中,有兴趣的朋友可以查看。
5. 结论
RAG 技术是一种强大的内容生成方法,它可以通过结合检索和生成模型来利用大量的数据集生成高质量的文章。欢迎尝试这些代码并根据自己的需要进行调整。
以下是写成class的代码。
from langchain.llms import tongyi
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import alibabacloud_opensearch, faiss
from langchain_community.embeddings import baidu_qianfan_endpoint
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from langchain_core.prompts import PromptTemplate
from FlagEmbedding import FlagReranker
from langchain_community.document_loaders import DirectoryLoader, text
from langchain_community.vectorstores.utils import DistanceStrategy
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForSequenceClassification
from tqdm import tqdm
import numpy as np
import torch
class RAGSystem:
    """Retrieval-augmented article generator.

    On construction the instance loads and splits the ``*.txt`` files of a
    data directory, loads an embedding model and a previously saved FAISS
    index, and combines the FAISS retriever with a BM25 retriever (equal
    weights). Retrieved passages are re-scored by a reranker before being
    placed into the LLM prompts.
    """

    def __init__(self, data_directory,
                 model_name="gpt-4o",
                 embedding_model_name='BAAI/bge-m3',
                 cache_folder="your_model_save_path",
                 vector_db_path="your_vec_model_save_path",
                 api_key="your_api",
                 temperature=1.2):
        """Build every component of the pipeline.

        Args:
            data_directory: Directory containing the text knowledge base.
            model_name: Chat model name passed to ``ChatOpenAI``.
            embedding_model_name: Hugging Face embedding model name.
            cache_folder: Directory where downloaded models are cached
                (used for both the embedding model and the reranker).
            vector_db_path: Path of the saved FAISS vector database.
            api_key: API key passed to ``ChatOpenAI``.
            temperature: Sampling temperature in (0.0, 2.0); higher values
                give broader output, lower values give stricter output.
        """
        self.llm = ChatOpenAI(model_name=model_name,
                              api_key=api_key,
                              temperature=temperature,
                              )
        self.data_directory = data_directory
        self.text_splitter = self._initialize_text_splitter()
        self.RDATA = self._load_and_split_data(self.data_directory)
        self.embedding_model = self._initialize_embedding_model(embedding_model_name, cache_folder)
        self.cosine_knowledge_vector_database = self._load_vector_db(vector_db_path)
        self.retriever_vectordb = self.cosine_knowledge_vector_database.as_retriever()
        self.keyword_retriever = BM25Retriever.from_documents(self.RDATA, k=50, bm25_params={'k1': 1.5})
        self.ensemble_retriever = EnsembleRetriever(
            retrievers=[self.retriever_vectordb, self.keyword_retriever],
            weights=[0.5, 0.5],
        )
        self.reranker = self._initialize_reranker(cache_folder)

    def _initialize_text_splitter(self):
        """Return the recursive splitter used to chunk documents."""
        # Separators are tried in order. The empty string must stay LAST:
        # the splitter stops searching once it reaches "", so anything
        # listed after it (here the Latin and CJK punctuation) would never
        # be used.
        MARKDOWN_SEPARATORS = [
            "\n#{1,6} ", "```\n", "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n", "\n\n", "\n", " ", ".", ",",
            "\u200B", "\uff0c", "\u3001", "\uff0e", "\u3002", "",
        ]
        return RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            add_start_index=True,
            strip_whitespace=True,
            separators=MARKDOWN_SEPARATORS,
        )

    def _load_and_split_data(self, path):
        """Load ``*.txt`` files under *path*, chunk them, and de-duplicate.

        Returns:
            Chunked documents with exact-duplicate texts removed; the first
            occurrence of each text is kept.
        """
        data_loader = DirectoryLoader(path, glob="*.txt", recursive=True, loader_cls=text.TextLoader)
        raw_knowledge_base = data_loader.load()
        data = []
        print("正在加载数据库")
        for doc in tqdm(raw_knowledge_base):
            data.extend(self.text_splitter.split_documents([doc]))
        # Remove duplicates while preserving order.
        seen_contents = set()
        RDATA = []
        for doc in data:
            if doc.page_content not in seen_contents:
                seen_contents.add(doc.page_content)
                RDATA.append(doc)
        return RDATA

    def _initialize_embedding_model(self, model_name, cache_folder):
        """Return the Hugging Face embedding model (normalized embeddings)."""
        print("正在加载embedding模型")
        return HuggingFaceEmbeddings(
            model_name=model_name,
            cache_folder=cache_folder,
            multi_process=True,
            model_kwargs={"device": "cuda:0", "trust_remote_code": True},
            encode_kwargs={"normalize_embeddings": True},
        )

    def _load_vector_db(self, vector_db_path):
        """Load a saved FAISS index from *vector_db_path*."""
        print("正在加载向量数据库")
        # SECURITY: allow_dangerous_deserialization=True unpickles the stored
        # index. Only point this at index files you created or fully trust.
        return faiss.FAISS.load_local(
            vector_db_path,
            embeddings=self.embedding_model,
            allow_dangerous_deserialization=True,
        )

    def _initialize_reranker(self, cache_folder="your_model_save_path"):
        """Return the reranker, caching its weights under *cache_folder*."""
        return FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True, cache_dir=cache_folder)

    def vectorize_data(self):
        """Build a new cosine-distance FAISS index from the loaded chunks."""
        return faiss.FAISS.from_documents(self.RDATA, self.embedding_model, distance_strategy=DistanceStrategy.COSINE)

    def add_vectorize_data(self, new_data_path, vec_data_path, save_path):
        """Add a new data directory to an existing index and save the result.

        Optional step; use it whenever new source data is added.

        Args:
            new_data_path: Directory containing the new text files.
            vec_data_path: Path of the existing vector database.
            save_path: Path where the merged vector database is written.
        """
        print("正在合并数据库")
        knowledge_vector_database = self._load_vector_db(vec_data_path)
        RDATA = self._load_and_split_data(new_data_path)
        knowledge_vector_database.add_documents(RDATA)
        knowledge_vector_database.save_local(save_path)

    def get_relevant_knowledge(self, query, final_extract_num=20):
        """Retrieve candidates for *query*, rerank them, and return the best.

        Args:
            query: Search text for the ensemble retriever and the reranker.
            final_extract_num: Maximum number of passages to return.

        Returns:
            Page contents of the top-scoring documents, best first. Empty
            when nothing is retrieved or ``final_extract_num`` is not
            positive.
        """
        # A slice of [-0:] would select *everything*, so handle it explicitly.
        if final_extract_num <= 0:
            return []
        relevant_doc = self.ensemble_retriever.invoke(query)
        if not relevant_doc:
            return []
        temp_docs = [[query, doc.page_content] for doc in relevant_doc]
        # atleast_1d guards against a reranker that returns a bare scalar
        # when it is given a single pair.
        score = np.atleast_1d(self.reranker.compute_score(temp_docs))
        # argsort is ascending: take the tail, then reverse to get best-first.
        index = np.argsort(score)[-final_extract_num:][::-1]
        return [relevant_doc[i].page_content for i in index]

    @staticmethod
    def _format_context(passages):
        """Concatenate passages into one numbered reference string."""
        return ''.join([f'文件{i}: {doc}' for i, doc in enumerate(passages)])

    def generate_script(self, query):
        """Generate an outline for the topic *query* from retrieved passages."""
        print("正在生成大纲")
        query_retrieval = self.get_relevant_knowledge(query)
        context = self._format_context(query_retrieval)
        script_prompt_template = PromptTemplate.from_template('''
参考资料:
{context}
你的大纲。
主题:
{query}
答案:
''')
        chain = script_prompt_template | self.llm
        script = chain.invoke({'query': query, 'context': context})
        return script.content

    def extract_keywords(self, query, script):
        """Ask the LLM for search keywords drawn from the outline *script*."""
        important_extraction_prompt = PromptTemplate.from_template('''
以下是一篇根据主题写的大纲,提取其中的关键词以供后续在资料库中搜索。
主题:
{query}
大纲:
{script}
''')
        chain = important_extraction_prompt | self.llm
        important = chain.invoke({"query": query, "script": script})
        return important.content

    def generate_article(self, query):
        """Generate an article: outline, keywords, second retrieval, writing.

        Args:
            query: Topic of the article.

        Returns:
            The text content of the LLM response.
        """
        # Draft an outline around the topic.
        script = self.generate_script(query)
        # Extract the keywords of the outline.
        important = self.extract_keywords(query, script)
        # Second retrieval pass, driven by those keywords.
        important_retrieval = self.get_relevant_knowledge(important)
        context = self._format_context(important_retrieval)
        print("正在生成文案")
        prompt_template = PromptTemplate.from_template('''
参考资料:
{context}
你的prompt。
大纲:
{script}
答案:
''')
        chain = prompt_template | self.llm
        article = chain.invoke({'script': script, 'context': context})
        return article.content
# Example usage. Guarded so that importing this module does not load the
# models or call the LLM as a side effect.
if __name__ == "__main__":
    # data_directory is the path of the text knowledge base.
    rag_system = RAGSystem(data_directory="your_data_path")
    query = ""  # fill in the article topic here
    answer = rag_system.generate_article(query)
    # The individual steps can also be run one at a time:
    # script = rag_system.generate_script(query)
    # keywords = rag_system.extract_keywords(query, script)
    # important_retrieval = rag_system.get_relevant_knowledge(keywords, final_extract_num=30)
GitHub 仓库:
如果您有任何疑问或想要进一步讨论,请随时留言或直接联系我。