结论速递
这个迷你项目手搓了一个最小的RAG系统。之前基于Langchain实现过RAG(不用chain),对RAG结构还算熟悉,因此核心放在构思如何手搓和对照思路与TinyRAG的实现上。
TinyRAG项目中几个使用langchain得不到的小收获:
- 使用JSON做persistent
- cosine similarity的计算和加速
- chunk的切割方法
目录
1 绪论
1.1 RAG
RAG全称:Retrieval-Augmented Generation 检索增强生成。含义很直观,就是用检索到的内容增强LLM的回答能力。
RAG实际上诞生于LLM爆火之前,在最早的论文里,retrieval那一块是要梯度更新训练的(Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks)。23年初LLM应用开发爆火之后,RAG被用来给LLM进行knowledge injection,给LLM外挂知识库,让它能针对特定问题做开卷问答,以降低回答的知识幻觉。
Naive RAG的实现非常非常简单,但是要做一个好的RAG系统并不容易,有很多小细节都可能影响RAG的效果。另外,目前的RAG的一些变体,比如在检索前初筛,或者检索后rerank等(如下图,出自Retrieval-Augmented Generation for Large Language Models: A Survey)其实也是在原流程上进行了一些改进得来的。因此,深入理解RAG的结构就变得非常重要。这也是参与手搓RAG的原因所在。
1.2 提前思考:如何手搓RAG
因为之前已经完成过RAG的项目,对RAG的流程比较有经验。这里尝试先根据已有经验梳理手搓RAG需要完成哪些工作。
数据流程图出自去年5月本人完成的一个使用Langchain实现的RAG项目:GitHub链接
可以考虑一个RAG的项目代码需要实现的核心部分包括:
- Basis:
- Embedding:把文字变成向量,方便计算文本相似度实现检索;检索完把向量复原为文字
- Similarity calculation:相似文本检索。因为文字已经变成向量,所以这里通过计算向量相似度实现。
- Vector database creation:
- File loader:加载文件
- Text splitter:针对给定的分隔符,将文件切割成一个一个的段落
- Vector database:存储embedded后的段落,实现输入query后检索段落
- Retrieval:根据query在vector database当中检索内容
- Chatbot:
- Prompt template:能往里填内容的prompt模板,需要填的内容包括:检索出来的文本信息,用户的query,历史问答(如果有)
- ChatLLM(Chatbase):调用LLM,给出回复。
- Chatbot:起到一个interface的作用,根据用户query给结果
参考Langchain,ChatLLM的过程和Chatbot是两个分开的封装,前者是输入结构化prompt,输出结果;后者是输入用户query,内部调用检索,构建prompt,调用LLM,最后输出结果
2 TinyRAG
2.1 项目结构
TinyRAG项目结构由5大部分构成:
- Embedding:向量化模块,用来将文档片段向量化。包含了Similarity calculation的实现,实现为abstract类的方法,被所有子类继承。
- Load and split:文档加载和切分的模块,用来加载文档并切分成文档片段。结合了File loader和Text splitter两部分的实现。
- Vector database:数据库,用来存放文档片段和对应的向量表示。
- Retrieval:检索模块,用来根据 Query 在检索相关的文档片段。
- LLM:大模型模块,用来根据检索出来的文档回答用户的问题。对应ChatLLM,包含了Prompt template的实现。
项目结构的思维导图如下:
除了在组成上有些出入,基本包含了前述1.2中的考虑。Chatbot部分是在demo中实现的,没有封装。
2.2 代码阅读
2.2.1 Embedding
完整代码见:Embeddings.py 代码地址,以下是笔记
- 通过调用Embedding模型的API实现embedding的核心功能
- 实现结构是abstract类定义接口和公用方法(余弦相似度计算,调用API实现embedding),然后覆写不同的子类继承
class BaseEmbeddings:
"""
Base class for embeddings
"""
def __init__(self, path: str, is_api: bool) -> None:
self.path = path
self.is_api = is_api
def get_embedding(self, text: str, model: str) -> List[float]:
# 调embedding API
...
@classmethod
def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
# 余弦相似度计算
...
- 分别实现了调用OpenAI,Jina,Zhipu,Dashscope的embedding API的子类
2.2.2 Load and split
完整代码见:utils.py 代码地址,以下是笔记
- 核心类是
ReadFiles
,里面杂糅了几大功能:- Load:读文件(PDF,markdown,txt,通过classmethod实现的)
- 单文件Split:根据分词符和token数量切割,同时实现了text overlap的功能,是在
get_chunk
这个方法中实现的 - 多文件load和chunk拼接:
get_content
方法,对每个文件,都read_file_content
读文本,然后调用get_chunk
,然后把获得的chunk拼接起来。
这个类功能太混杂,看起来很累,一点修改的小建议(或者我也可以提pr):
- 把单文件和多文件loader分开
- 单文件loader也是abstract类(规范化接口)和子类的形式(不同文件格式,设定不同分隔符),把chunk和load放这里
- 多文件loader就放filelist和get content(然后既然已经实现多文件了,完全可以顺手再加一个metadata,就放个文件名也好)
另外,get_chunk
里token的用法我有怀疑,这里应该不是token是字数?
def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):
...
2.2.3 Vector database & Retrieval
完整代码见VectorBase.py 代码地址,以下是笔记
一个VectorStore类实现了:
- text->embedding :
get_vector
,对每一个chunk,调用embedding实现 - Retrieval:
get_similarity
,计算两个vector的similarity(这里可以考虑写成内部方法)query
,根据query检索指定数量的relevant chunk,因为多chunk,用的numpy加速similarity的计算
def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
query_vector = EmbeddingModel.get_embedding(query)
result = np.array([self.get_similarity(query_vector, vector)
for vector in self.vectors])
return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()
- 永久化存储:
persist
,用的JSON
存 - 从永久化地址读取:
load_vector
2.2.4 LLM
完整代码见:LLM.py 代码地址,以下是笔记
PROMPT_TEMPLATE
是用dict
实现的,结合字符串的format填入context和query实现格式化的prompt。- LLMChat是abstract基类定义接口和不同模型子类继承的方式
class BaseModel:
def __init__(self, path: str = '') -> None:
self.path = path
def chat(self, prompt: str, history: List[dict], content: str) -> str:
pass
def load_model(self):
pass
- 实现了OpenAI,InternLM,DashscopeChat,和ZhipuChat四种LLMChat