TinyRAG 实现指南
本文将带领大家一步一步实现一个简单的RAG模型,即Tiny-RAG。Tiny-RAG是RAG的简化版本,仅包含RAG的核心功能:检索(Retrieval)和生成(Generation)。通过构建Tiny-RAG,帮助大家更好地理解RAG模型的原理和实现。
1. RAG 介绍
大语言模型(LLM)会产生误导性的“幻觉”,依赖的信息可能过时,处理特定知识时效率不高,缺乏专业领域的深度洞察,并且在推理能力上有所欠缺。检索增强生成(Retrieval-Augmented Generation,RAG)应时而生,通过在生成答案之前从广泛的文档数据库中检索相关信息,引导生成过程,提高内容的准确性和相关性,缓解了幻觉问题。
RAG的基本结构包括:
- 向量化模块:将文档片段向量化。
- 文档加载和切分模块:加载文档并切分成文档片段。
- 数据库:存放文档片段和对应的向量表示。
- 检索模块:根据Query检索相关的文档片段。
- 大模型模块:根据检索出来的文档回答用户的问题。
2. 向量化
首先,实现一个向量化类,用于将文档片段向量化。设置一个Embedding基类,便于扩展其他模型。
class BaseEmbeddings:
def __init__(self, path: str, is_api: bool) -> None:
self.path = path
self.is_api = is_api
def get_embedding(self, text: str, model: str) -> List[float]:
raise NotImplementedError
@classmethod
def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
dot_product = np.dot(vector1, vector2)
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
return dot_product / magnitude if magnitude else 0
继承基类编写具体的Embedding类,例如使用OpenAI的Embedding API:
class OpenAIEmbedding(BaseEmbeddings:
def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api)
if self.is_api:
from openai import OpenAI
self.client = OpenAI()
self.client.api_key = os.getenv("OPENAI_API_KEY")
def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
if self.is_api:
text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=model).data[0].embedding
3. 文档加载和切分
实现文档加载和切分类,将文档切分成片段。支持pdf、md、txt文件。
def read_file_content(cls, file_path: str):
if file_path.endswith('.pdf'):
return cls.read_pdf(file_path)
elif file_path.endswith('.md'):
return cls.read_markdown(file_path)
elif file_path.endswith('.txt'):
return cls.read_text(file_path)
else:
raise ValueError("Unsupported file type")
def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):
chunk_text = []
curr_len = 0
curr_chunk = ''
lines = text.split('\n')
for line in lines:
line = line.replace(' ', '')
line_len = len(enc.encode(line))
if curr_len + line_len <= max_token_len:
curr_chunk += line + '\n'
curr_len += line_len + 1
else:
chunk_text.append(curr_chunk)
curr_chunk = curr_chunk[-cover_content:] + line
curr_len = line_len + cover_content
if curr_chunk:
chunk_text.append(curr_chunk)
return chunk_text
4. 数据库 && 向量检索
设计一个向量数据库,存放文档片段和向量表示,并实现检索模块。
class VectorStore:
def __init__(self, document: List[str] = ['']) -> None:
self.document = document
def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
pass
def persist(self, path: str = 'storage'):
pass
def load_vector(self, path: str = 'storage'):
pass
def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
query_vector = EmbeddingModel.get_embedding(query)
result = np.array([self.get_similarity(query_vector, vector) for vector in self.vectors])
return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()
5. 大模型模块
实现大模型模块,根据检索出来的文档回答用户的问题。以InternLM2-chat-7B模型为例:
class BaseModel:
def __init__(self, path: str = '') -> None:
self.path = path
def chat(self, prompt: str, history: List[dict], content: str) -> str:
pass
def load_model(self):
pass
class InternLMChat(BaseModel):
def __init__(self, path: str = '') -> None:
super().__init__(path)
self.load_model()
def chat(self, prompt: str, history: List = [], content: str='') -> str:
prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
response, history = self.model.chat(self.tokenizer, prompt, history)
return response
def load_model(self):
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
PROMPT_TEMPLATE = dict(
InternLM_PROMPT_TEMPALTE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
问题: {question}
可参考的上下文:
···
{context}
···
如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
有用的回答:"""
)
6. LLM Tiny-RAG Demo
示例代码展示了如何将各模块结合,完成Tiny-RAG的搭建和使用。
from RAG.VectorBase import VectorStore
from RAG.utils import ReadFiles
from RAG.LLM import OpenAIChat, InternLMChat
from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding
# 没有保存数据库
docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150)
vector = VectorStore(docs)
embedding = ZhipuEmbedding()
vector.get_vector(EmbeddingModel=embedding)
vector.persist(path='storage')
question = 'git的原理是什么?'
content = vector.query(question, model='zhipu', k=1)[0]
chat = InternLMChat(path='model_path')
print(chat.chat(question, [], content))
7. 总结
Tiny-RAG模型的实现帮助我们理解了RAG的基本原理和流程。一个最小的RAG结构包括向量化模块、文档加载和切分模块、数据库、向量检索和大模型模块。通过这些模块的实现,可以显著提升生成内容的准确性和相关性。