【RAG入门必备技能】Faiss框架使用与FaissRetriever实现

Faiss介绍

faiss是一个Facebook AI团队开源的库,全称为Facebook AI Similarity Search,该开源库针对高维空间中的海量数据(稠密向量),提供了高效且可靠的相似性聚类和检索方法,可支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库

官方资源地址https://github.com/facebookresearch/faiss

Faiss基础依赖

1)矩阵计算框架:Faiss与计算资源之间需要一个外部依赖框架,这个框架是一个矩阵计算框架,官方默认配置安装的是OpenBlas,另外也可以用Intel的MKL,相比于OpenBlas使用MKL作为框架进行编译可以提高一定的稳定性。

2)OpenMP:如果向量之间的相似性搜索是逐条进行的那计算效率会非常低,而Faiss内部实现使用了OpenMP,可以以batch的形式来进行搜素,实现计算效率的提升。

Faiss工作数据流

在使用Faiss进行query向量的相似性搜索之前,需要将原始的向量集构建封装成一个索引文件(index file)并缓存在内存中,提供实时的查询计算。在第一次构建索引文件的时候,需要经过Train和Add两个过程。后续如果有新的向量需要被添加到索引文件的话还可以有一个Add操作从而实现增量build索引。

Faiss的核心

Faiss本质上是一个向量(矢量)数据库。进行搜索时,基础是原始向量数据库,基本单位是单个向量,默认输入一个向量x,返回和x最相似的k个向量。其中的核心就是索引(index对象),Index继承了一组向量库,作用是对原始向量集进行预处理和封装,一般操作包括train和add,可以建成一个索引对象缓存在计算机内存中。所有向量在建立前需要明确向量的维度d,大多数的索引还需要训练阶段来分析向量的分布(除了IndexFlatL2)。当索引被建立就可以进行后续的search操作了。

Train:

目的:生成原向量中心点,残差(向量中心点的差值)向量中心点,部分预计算的距离

流程:

1)把原始向量分成M个子空间,针对每个子空间训练中心点(如果每个子空间的中心点为n,则pq可表达n的M次方个中心点)。

2)查找向量对应的中心点

3)向量减去对应的中心点生成残差向量

4)针对残差向量生成二级量化器。

Search:

Search操作时索引的重要部分,search方法涉及实际的相似度计算,返回的检索结果包括两个矩阵,分别为xq中元素与近邻的距离大小和近邻向量的索引序号。

使用方法

Faiss是为稠密向量提供高效相似度搜索的框架(Facebook AI Research),选择索引方式是faiss的核心内容,faiss 三个最常用的索引是:IndexFlatL2, IndexIVFFlat,IndexIVFPQ。

  • IndexFlatL2/ IndexFlatIP为最基础的精确查找。

  • IndexIVFFlat称为倒排文件索引,是使用K-means建立聚类中心,通过查询最近的聚类中心,比较聚类中的所有向量得到相似的向量,是一种加速搜索方法的索引。


  • IndexIVFPQ是一种减少内存的索引方式,IndexFlatL2和IndexIVFFlat都会全量存储所有的向量在内存中,面对大数据量,faiss提供一种基于Product Quantizer(乘积量化)的压缩算法编码向量到指定字节数来减少内存占用。但这种情况下,存储的向量是压缩过的,所以查询的距离也是近似的。


下面以为代码的方式讲解Faiss使用

构建句子向量表示模型

import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss
import numpy as np
import os
from tqdm import tqdm
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

在下面的代码片段中,我们建立了一个SemanticEmbedding类,它使用预先训练的 MPNet 模型将文本编码为向量。这里我选用了sbert的多语言模型paraphrase-multilingual-mpnet-base-v2

class SemanticEmbedding:

    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    # Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_embedding(self, sentences):
        # Tokenize sentences
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.detach().numpy()

调用方法如下:

 model = SemanticEmbedding(r'I:\pretrained_models\bert\english\paraphrase-multilingual-mpnet-base-v2')
a = model.get_embedding("我喜欢打篮球")
print(a)
print(a.shape)

构建 Faiss 索引

Faiss 根据用户所需的功能提供多种不同索引选项。在下面代码中我们使用 IndexFlatIP,因为它的距离机制是内积,对于规范化嵌入而言,它与余弦相似度相同,值越大越相似。

在下面的代码中,我们创建了一个FaissIdx类,该类使用嵌入向量大小(在本例中为 768)初始化我们的索引,并使用计数器跟踪文档的 ID。请注意,Faiss 索引当前为空 — 它不包含任何文档向量。

然后向该类添加了两个方法来添加和搜索文档。为了实现这些方法,使用 Faiss 的 API 来获取文档的嵌入并搜索查询向量。请注意,搜索 API self.index.search 也接受参数 k 作为输入,它定义要返回的文档向量数量。在本例中,我们告诉该方法返回前三个文档向量,使用我们定义的 doc_map 数据结构来返回一个人类可读的文档。

class FaissIdx:

    def __init__(self, model, dim=768):
        self.index = faiss.IndexFlatIP(dim)
        # Maintaining the document data
        self.doc_map = dict()
        self.model = model
        self.ctr = 0

    def add_doc(self, document_text):
        self.index.add(self.model.get_embedding(document_text))
        self.doc_map[self.ctr] = document_text # store the original document text
        self.ctr += 1

    def search_doc(self, query, k=3):
        D, I = self.index.search(self.model.get_embedding(query), k)
        return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]

测试索引

index = FaissIdx(model)
index.add_doc("笔记本电脑")
index.add_doc("医生的办公室")
result=index.search_doc("个人电脑")
print(result)

设置索引后,我们可以添加其他文档并对其进行搜索。在下面的代码中,我们添加文档“笔记本电脑”和“医生办公室”,然后搜索“PC 电脑”。请注意,“笔记本电脑”的相似度很高,而“医生办公室”的相似度很低,这是有道理的。请记住,余弦相似度的范围从 -1 到 1,其中 1 表示完全相似。

[{'笔记本电脑': 0.9483323}, {'医生的办公室': 0.39533424}]

构建语料索引

# 加载测试文档
data=pd.read_json('../../data/zh_refine.json', lines=True)[:50]
print(data)
print(data.columns)

for documents in tqdm(data['positive'],total=len(data)):
    for document in documents:
       index.add_doc(document)

for documents in tqdm(data['negative'],total=len(data)):
    for document in documents:
         index.add_doc(document)

result=index.search_doc("2022年特斯拉交付量")
print(result)

检索结果如下:

FaissRetriever实现

欢迎大家star正在开发的Gomate工具:https://github.com/gomate-community/GoMate

完整代码实现

https://github.com/gomate-community/GoMate/blob/main/gomate/modules/retrieval/faiss_retriever.py

import json
import os
import random
from concurrent.futures import ProcessPoolExecutor
from typing import List, Any

import faiss
import numpy as np
import tiktoken
from tqdm import tqdm

from gomate.modules.retrieval.embedding import BaseEmbeddingModel, OpenAIEmbeddingModel
from gomate.modules.retrieval.embedding import SBertEmbeddingModel
from gomate.modules.retrieval.retrievers import BaseRetriever
from gomate.modules.retrieval.utils import split_text


class FaissRetrieverConfig:
    def __init__(
            self,
            max_tokens=100,
            max_context_tokens=3500,
            use_top_k=True,
            embedding_model=None,
            question_embedding_model=None,
            top_k=5,
            tokenizer=None,
            embedding_model_string=None,
            index_path=None,
            rebuild_index=True
    ):
        if max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")

        if top_k < 1:
            raise ValueError("top_k must be at least 1")

        if max_context_tokens is not None and max_context_tokens < 1:
            raise ValueError("max_context_tokens must be at least 1 or None")

        if embedding_model is not None and not isinstance(
                embedding_model, BaseEmbeddingModel
        ):
            raise ValueError(
                "embedding_model must be an instance of BaseEmbeddingModel or None"
            )

        if question_embedding_model is not None and not isinstance(
                question_embedding_model, BaseEmbeddingModel
        ):
            raise ValueError(
                "question_embedding_model must be an instance of BaseEmbeddingModel or None"
            )

        self.top_k = top_k
        self.max_tokens = max_tokens
        self.max_context_tokens = max_context_tokens
        self.use_top_k = use_top_k
        self.embedding_model = embedding_model or OpenAIEmbeddingModel()
        self.question_embedding_model = question_embedding_model or self.embedding_model
        self.tokenizer = tokenizer
        self.embedding_model_string = embedding_model_string or "OpenAI"
        self.index_path = index_path
        self.rebuild_index=rebuild_index

    def log_config(self):
        config_summary = """
		FaissRetrieverConfig:
			Max Tokens: {max_tokens}
			Max Context Tokens: {max_context_tokens}
			Use Top K: {use_top_k}
			Embedding Model: {embedding_model}
			Question Embedding Model: {question_embedding_model}
			Top K: {top_k}
			Tokenizer: {tokenizer}
			Embedding Model String: {embedding_model_string}
			Index Path: {index_path}
			Rebuild Index Path: {rebuild_index}
		""".format(
            max_tokens=self.max_tokens,
            max_context_tokens=self.max_context_tokens,
            use_top_k=self.use_top_k,
            embedding_model=self.embedding_model,
            question_embedding_model=self.question_embedding_model,
            top_k=self.top_k,
            tokenizer=self.tokenizer,
            embedding_model_string=self.embedding_model_string,
            index_path=self.index_path,
            rebuild_index=self.rebuild_index
        )
        return config_summary


class FaissRetriever(BaseRetriever):
    """
    FaissRetriever is a class that retrieves similar context chunks for a given query using Faiss.
    encoders_type is 'same' if the question and context encoder is the same,
    otherwise, encoders_type is 'different'.
    """

    def __init__(self, config):
        self.embedding_model = config.embedding_model
        self.question_embedding_model = config.question_embedding_model
        self.index = None
        self.context_chunks = []
        self.max_tokens = config.max_tokens
        self.max_context_tokens = config.max_context_tokens
        self.use_top_k = config.use_top_k
        self.tokenizer = config.tokenizer
        self.top_k = config.top_k
        self.embedding_model_string = config.embedding_model_string
        self.index_path = config.index_path
        self.rebuild_index=config.rebuild_index
        # Load the index from the specified path if it is not None
        if not self.rebuild_index:
            if self.index_path and os.path.exists(self.index_path):
                self.load_index(self.index_path)
        else:
            os.remove(self.index_path)

    def load_index(self, index_path):
        """
        Loads a Faiss index from a specified path.

        :param index_path: Path to the Faiss index file.
        """
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            print("Index loaded successfully.")
        else:
            print("Index path does not exist.")

    def encode_document(self, doc_text):
        """
        Builds the index from a given text.

        :param doc_text: A string containing the document text.
        """
        # Split the text into context chunks
        context_chunks = np.array(
            split_text(doc_text, self.tokenizer, self.max_tokens)
        )
        # Collect embeddings using a for loop
        embeddings = []
        for context_chunk in context_chunks:
            embedding = self.embedding_model.create_embedding(context_chunk)
            embeddings.append(embedding)

        embeddings = np.array(embeddings, dtype=np.float32)
        return embeddings,context_chunks.tolist()
    def build_from_texts(self, documents):
        """
        Processes multiple documents in batches, builds the index, and saves it to disk.

        :param documents: List of document texts to process.
        :param save_path: Path to save the index file.
        :param batch_size: Number of documents to process in each batch.
        """
        self.all_embeddings = []
        self.context_chunks=[]
        for i in tqdm(range(0, len(documents))):
            doc_embeddings,context_chunks = self.encode_document(documents[i])
            self.all_embeddings.append(doc_embeddings)
            self.context_chunks.extend(context_chunks)
        # Initialize the index only once
        if self.index is None and self.all_embeddings:
            self.index = faiss.IndexFlatIP(self.all_embeddings[0].shape[1])

        self.all_embeddings = np.vstack(self.all_embeddings)
        print(self.all_embeddings.shape)
        print(len(self.context_chunks))
        self.index.add(self.all_embeddings)
        # Save the index to disk
        faiss.write_index(self.index, self.index_path)
    def sanity_check(self, num_samples=4):
        """
        Perform a sanity check by recomputing embeddings of a few randomly-selected chunks.

        :param num_samples: The number of samples to test.
        """
        indices = random.sample(range(len(self.context_chunks)), num_samples)

        for i in indices:
            original_embedding = self.all_embeddings[i]
            recomputed_embedding = self.embedding_model.create_embedding(
                self.context_chunks[i]
            )
            assert np.allclose(
                original_embedding, recomputed_embedding
            ), f"Embeddings do not match for index {i}!"

        print(f"Sanity check passed for {num_samples} random samples.")

    def retrieve(self, query: str) -> list[Any]:
        """
        Retrieves the k most similar context chunks for a given query.

        :param query: A string containing the query.
        :param k: An integer representing the number of similar context chunks to retrieve.
        :return: A string containing the retrieved context chunks.
        """
        query_embedding = np.array(
            [
                np.array(
                    self.question_embedding_model.create_embedding(query),
                    dtype=np.float32,
                ).squeeze()
            ]
        )

        context = []

        if self.use_top_k:
            distances, indices = self.index.search(query_embedding, self.top_k)
            print(distances,indices)
            print(distances[0][2],indices)
            for i in range(self.top_k):
                context.append({'text':self.context_chunks[indices[0][i]],'score':distances[0][i]})
        else:
            range_ = int(self.max_context_tokens / self.max_tokens)
            _, indices = self.index.search(query_embedding, range_)
            total_tokens = 0
            for i in range(range_):
                tokens = len(self.tokenizer.encode(self.context_chunks[indices[0][i]]))
                context.append(self.context_chunks[indices[0][i]])
                if total_tokens + tokens > self.max_context_tokens:
                    break
                total_tokens += tokens

        return context


if __name__ == '__main__':
    from transformers import AutoTokenizer

    embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
    embedding_model = SBertEmbeddingModel(embedding_model_path)
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
    retriever_config = FaissRetrieverConfig(
        max_tokens=100,
        max_context_tokens=3500,
        use_top_k=True,
        embedding_model=embedding_model,
        top_k=5,
        tokenizer=tokenizer,
        embedding_model_string="bge-large-zh-v1.5",
        index_path="faiss_index.bin",
        rebuild_index=True
    )

    faiss_retriever=FaissRetriever(config=retriever_config)

    documents=[]
    with open('/home/test/codes/GoMate/data/zh_refine.json','r',encoding="utf-8") as f:
        for line in f.readlines():
            data=json.loads(line)
            documents.extend(data['positive'])
            documents.extend(data['negative'])
    print(len(documents))
    faiss_retriever.build_from_texts(documents[:200])

    contexts=faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
    print(contexts)

参考资料

### 基于 FAISSRAG 模型实现 #### 什么是 FAISSFAISS 是 Facebook 开发的一个高效相似度搜索库,特别适合处理大规模向量数据集。其核心功能是对高维空间中的向量进行快速近似最近邻搜索[^1]。 #### 使用 FAISS 构建 RAG 检索增强生成模型的关键步骤 以下是基于 FAISS 实现 RAG 模型的主要技术要点: 1. **构建知识库** 需要先将外部知识源转化为结构化形式并嵌入到向量表示中。这通常涉及以下过程: - 将文档分割成较小的片段(如句子或段落)。 - 利用预训练的语言模型(如 BERT 或 Sentence Transformers)提取这些片段的语义特征向量[^3]。 2. **存储索引** 使用 FAISS 创建高效的索引机制来加速检索操作。具体来说: - 初始化一个 FAISS 索引对象,并加载已计算好的向量集合。 - 对索引执行优化配置以提升查询效率,比如设置 `nlist` 和 `nprobe` 参数。 3. **实时检索** 当接收到输入时,通过编码器将其转换为目标向量,随后调用 FAISS 执行 k-NN 查找获取最接近的知识条目及其对应上下文信息。 4. **融合生成** 结合检索得到的相关内容作为条件输入传递给解码端完成最终文本生成任务[^2]。 下面给出一段 Python 示例代码展示上述流程的具体实现方式: ```python import faiss from sentence_transformers import SentenceTransformer import numpy as np # 加载预训练句向量化模型 encoder_model = SentenceTransformer('paraphrase-MiniLM-L6-v2') def build_faiss_index(corpus_texts): """构建 FAISS 索引""" embeddings = encoder_model.encode(corpus_texts, convert_to_tensor=True).cpu().numpy() index = faiss.IndexFlatL2(embeddings.shape[1]) # L2 距离衡量法 index.add(np.array(embeddings)) return index def retrieve_top_k(query_text, faiss_index, topk=5): """根据查询字符串返回前 K 条匹配项""" query_vector = encoder_model.encode([query_text], show_progress_bar=False)[0] D, I = faiss_index.search(np.expand_dims(query_vector, axis=0), topk) return [(i, d) for i, d in zip(I.flatten(), D.flatten())] corpus = ["机器学习是一门多领域交叉学科", "深度神经网络由多个层组成"] faiss_index = build_faiss_index(corpus) # 测试检索逻辑 results = retrieve_top_k("人工智能是什么?", faiss_index) for idx, dist in results: print(f"Index {idx}, Distance {dist}: {corpus[idx]}") ``` 此脚本展示了如何利用 FAISS 进行简单的检索演示,实际应用中还需要考虑更多细节问题,例如分布式部署、动态更新以及性能调优等方面的内容。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值