【RAG入门必备技能】Faiss框架使用与FaissRetriever实现

Faiss介绍

faiss是一个Facebook AI团队开源的库,全称为Facebook AI Similarity Search,该开源库针对高维空间中的海量数据(稠密向量),提供了高效且可靠的相似性聚类和检索方法,可支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库

官方资源地址https://github.com/facebookresearch/faiss

Faiss基础依赖

1)矩阵计算框架:Faiss与计算资源之间需要一个外部依赖框架,这个框架是一个矩阵计算框架,官方默认配置安装的是OpenBlas,另外也可以用Intel的MKL,相比于OpenBlas使用MKL作为框架进行编译可以提高一定的稳定性。

2)OpenMP:如果向量之间的相似性搜索是逐条进行的那计算效率会非常低,而Faiss内部实现使用了OpenMP,可以以batch的形式来进行搜素,实现计算效率的提升。

Faiss工作数据流

在使用Faiss进行query向量的相似性搜索之前,需要将原始的向量集构建封装成一个索引文件(index file)并缓存在内存中,提供实时的查询计算。在第一次构建索引文件的时候,需要经过Train和Add两个过程。后续如果有新的向量需要被添加到索引文件的话还可以有一个Add操作从而实现增量build索引。

Faiss的核心

Faiss本质上是一个向量(矢量)数据库。进行搜索时,基础是原始向量数据库,基本单位是单个向量,默认输入一个向量x,返回和x最相似的k个向量。其中的核心就是索引(index对象),Index继承了一组向量库,作用是对原始向量集进行预处理和封装,一般操作包括train和add,可以建成一个索引对象缓存在计算机内存中。所有向量在建立前需要明确向量的维度d,大多数的索引还需要训练阶段来分析向量的分布(除了IndexFlatL2)。当索引被建立就可以进行后续的search操作了。

Train:

目的:生成原向量中心点,残差(向量中心点的差值)向量中心点,部分预计算的距离

流程:

1)把原始向量分成M个子空间,针对每个子空间训练中心点(如果每个子空间的中心点为n,则pq可表达n的M次方个中心点)。

2)查找向量对应的中心点

3)向量减去对应的中心点生成残差向量

4)针对残差向量生成二级量化器。

Search:

Search操作时索引的重要部分,search方法涉及实际的相似度计算,返回的检索结果包括两个矩阵,分别为xq中元素与近邻的距离大小和近邻向量的索引序号。

使用方法

Faiss是为稠密向量提供高效相似度搜索的框架(Facebook AI Research),选择索引方式是faiss的核心内容,faiss 三个最常用的索引是:IndexFlatL2, IndexIVFFlat,IndexIVFPQ。

  • IndexFlatL2/ IndexFlatIP为最基础的精确查找。

  • IndexIVFFlat称为倒排文件索引,是使用K-means建立聚类中心,通过查询最近的聚类中心,比较聚类中的所有向量得到相似的向量,是一种加速搜索方法的索引。


  • IndexIVFPQ是一种减少内存的索引方式,IndexFlatL2和IndexIVFFlat都会全量存储所有的向量在内存中,面对大数据量,faiss提供一种基于Product Quantizer(乘积量化)的压缩算法编码向量到指定字节数来减少内存占用。但这种情况下,存储的向量是压缩过的,所以查询的距离也是近似的。


下面以为代码的方式讲解Faiss使用

构建句子向量表示模型

import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss
import numpy as np
import os
from tqdm import tqdm
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

在下面的代码片段中,我们建立了一个SemanticEmbedding类,它使用预先训练的 MPNet 模型将文本编码为向量。这里我选用了sbert的多语言模型paraphrase-multilingual-mpnet-base-v2

class SemanticEmbedding:

    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    # Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_embedding(self, sentences):
        # Tokenize sentences
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.detach().numpy()

调用方法如下:

 model = SemanticEmbedding(r'I:\pretrained_models\bert\english\paraphrase-multilingual-mpnet-base-v2')
a = model.get_embedding("我喜欢打篮球")
print(a)
print(a.shape)

构建 Faiss 索引

Faiss 根据用户所需的功能提供多种不同索引选项。在下面代码中我们使用 IndexFlatIP,因为它的距离机制是内积,对于规范化嵌入而言,它与余弦相似度相同,值越大越相似。

在下面的代码中,我们创建了一个FaissIdx类,该类使用嵌入向量大小(在本例中为 768)初始化我们的索引,并使用计数器跟踪文档的 ID。请注意,Faiss 索引当前为空 — 它不包含任何文档向量。

然后向该类添加了两个方法来添加和搜索文档。为了实现这些方法,使用 Faiss 的 API 来获取文档的嵌入并搜索查询向量。请注意,搜索 API self.index.search 也接受参数 k 作为输入,它定义要返回的文档向量数量。在本例中,我们告诉该方法返回前三个文档向量,使用我们定义的 doc_map 数据结构来返回一个人类可读的文档。

class FaissIdx:

    def __init__(self, model, dim=768):
        self.index = faiss.IndexFlatIP(dim)
        # Maintaining the document data
        self.doc_map = dict()
        self.model = model
        self.ctr = 0

    def add_doc(self, document_text):
        self.index.add(self.model.get_embedding(document_text))
        self.doc_map[self.ctr] = document_text # store the original document text
        self.ctr += 1

    def search_doc(self, query, k=3):
        D, I = self.index.search(self.model.get_embedding(query), k)
        return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]

测试索引

index = FaissIdx(model)
index.add_doc("笔记本电脑")
index.add_doc("医生的办公室")
result=index.search_doc("个人电脑")
print(result)

设置索引后,我们可以添加其他文档并对其进行搜索。在下面的代码中,我们添加文档“笔记本电脑”和“医生办公室”,然后搜索“PC 电脑”。请注意,“笔记本电脑”的相似度很高,而“医生办公室”的相似度很低,这是有道理的。请记住,余弦相似度的范围从 -1 到 1,其中 1 表示完全相似。

[{'笔记本电脑': 0.9483323}, {'医生的办公室': 0.39533424}]

构建语料索引

# 加载测试文档
data=pd.read_json('../../data/zh_refine.json', lines=True)[:50]
print(data)
print(data.columns)

for documents in tqdm(data['positive'],total=len(data)):
    for document in documents:
       index.add_doc(document)

for documents in tqdm(data['negative'],total=len(data)):
    for document in documents:
         index.add_doc(document)

result=index.search_doc("2022年特斯拉交付量")
print(result)

检索结果如下:

FaissRetriever实现

欢迎大家star正在开发的Gomate工具:https://github.com/gomate-community/GoMate

完整代码实现

https://github.com/gomate-community/GoMate/blob/main/gomate/modules/retrieval/faiss_retriever.py

import json
import os
import random
from concurrent.futures import ProcessPoolExecutor
from typing import List, Any

import faiss
import numpy as np
import tiktoken
from tqdm import tqdm

from gomate.modules.retrieval.embedding import BaseEmbeddingModel, OpenAIEmbeddingModel
from gomate.modules.retrieval.embedding import SBertEmbeddingModel
from gomate.modules.retrieval.retrievers import BaseRetriever
from gomate.modules.retrieval.utils import split_text


class FaissRetrieverConfig:
    def __init__(
            self,
            max_tokens=100,
            max_context_tokens=3500,
            use_top_k=True,
            embedding_model=None,
            question_embedding_model=None,
            top_k=5,
            tokenizer=None,
            embedding_model_string=None,
            index_path=None,
            rebuild_index=True
    ):
        if max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")

        if top_k < 1:
            raise ValueError("top_k must be at least 1")

        if max_context_tokens is not None and max_context_tokens < 1:
            raise ValueError("max_context_tokens must be at least 1 or None")

        if embedding_model is not None and not isinstance(
                embedding_model, BaseEmbeddingModel
        ):
            raise ValueError(
                "embedding_model must be an instance of BaseEmbeddingModel or None"
            )

        if question_embedding_model is not None and not isinstance(
                question_embedding_model, BaseEmbeddingModel
        ):
            raise ValueError(
                "question_embedding_model must be an instance of BaseEmbeddingModel or None"
            )

        self.top_k = top_k
        self.max_tokens = max_tokens
        self.max_context_tokens = max_context_tokens
        self.use_top_k = use_top_k
        self.embedding_model = embedding_model or OpenAIEmbeddingModel()
        self.question_embedding_model = question_embedding_model or self.embedding_model
        self.tokenizer = tokenizer
        self.embedding_model_string = embedding_model_string or "OpenAI"
        self.index_path = index_path
        self.rebuild_index=rebuild_index

    def log_config(self):
        config_summary = """
		FaissRetrieverConfig:
			Max Tokens: {max_tokens}
			Max Context Tokens: {max_context_tokens}
			Use Top K: {use_top_k}
			Embedding Model: {embedding_model}
			Question Embedding Model: {question_embedding_model}
			Top K: {top_k}
			Tokenizer: {tokenizer}
			Embedding Model String: {embedding_model_string}
			Index Path: {index_path}
			Rebuild Index Path: {rebuild_index}
		""".format(
            max_tokens=self.max_tokens,
            max_context_tokens=self.max_context_tokens,
            use_top_k=self.use_top_k,
            embedding_model=self.embedding_model,
            question_embedding_model=self.question_embedding_model,
            top_k=self.top_k,
            tokenizer=self.tokenizer,
            embedding_model_string=self.embedding_model_string,
            index_path=self.index_path,
            rebuild_index=self.rebuild_index
        )
        return config_summary


class FaissRetriever(BaseRetriever):
    """
    FaissRetriever is a class that retrieves similar context chunks for a given query using Faiss.
    encoders_type is 'same' if the question and context encoder is the same,
    otherwise, encoders_type is 'different'.
    """

    def __init__(self, config):
        self.embedding_model = config.embedding_model
        self.question_embedding_model = config.question_embedding_model
        self.index = None
        self.context_chunks = []
        self.max_tokens = config.max_tokens
        self.max_context_tokens = config.max_context_tokens
        self.use_top_k = config.use_top_k
        self.tokenizer = config.tokenizer
        self.top_k = config.top_k
        self.embedding_model_string = config.embedding_model_string
        self.index_path = config.index_path
        self.rebuild_index=config.rebuild_index
        # Load the index from the specified path if it is not None
        if not self.rebuild_index:
            if self.index_path and os.path.exists(self.index_path):
                self.load_index(self.index_path)
        else:
            os.remove(self.index_path)

    def load_index(self, index_path):
        """
        Loads a Faiss index from a specified path.

        :param index_path: Path to the Faiss index file.
        """
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            print("Index loaded successfully.")
        else:
            print("Index path does not exist.")

    def encode_document(self, doc_text):
        """
        Builds the index from a given text.

        :param doc_text: A string containing the document text.
        """
        # Split the text into context chunks
        context_chunks = np.array(
            split_text(doc_text, self.tokenizer, self.max_tokens)
        )
        # Collect embeddings using a for loop
        embeddings = []
        for context_chunk in context_chunks:
            embedding = self.embedding_model.create_embedding(context_chunk)
            embeddings.append(embedding)

        embeddings = np.array(embeddings, dtype=np.float32)
        return embeddings,context_chunks.tolist()
    def build_from_texts(self, documents):
        """
        Processes multiple documents in batches, builds the index, and saves it to disk.

        :param documents: List of document texts to process.
        :param save_path: Path to save the index file.
        :param batch_size: Number of documents to process in each batch.
        """
        self.all_embeddings = []
        self.context_chunks=[]
        for i in tqdm(range(0, len(documents))):
            doc_embeddings,context_chunks = self.encode_document(documents[i])
            self.all_embeddings.append(doc_embeddings)
            self.context_chunks.extend(context_chunks)
        # Initialize the index only once
        if self.index is None and self.all_embeddings:
            self.index = faiss.IndexFlatIP(self.all_embeddings[0].shape[1])

        self.all_embeddings = np.vstack(self.all_embeddings)
        print(self.all_embeddings.shape)
        print(len(self.context_chunks))
        self.index.add(self.all_embeddings)
        # Save the index to disk
        faiss.write_index(self.index, self.index_path)
    def sanity_check(self, num_samples=4):
        """
        Perform a sanity check by recomputing embeddings of a few randomly-selected chunks.

        :param num_samples: The number of samples to test.
        """
        indices = random.sample(range(len(self.context_chunks)), num_samples)

        for i in indices:
            original_embedding = self.all_embeddings[i]
            recomputed_embedding = self.embedding_model.create_embedding(
                self.context_chunks[i]
            )
            assert np.allclose(
                original_embedding, recomputed_embedding
            ), f"Embeddings do not match for index {i}!"

        print(f"Sanity check passed for {num_samples} random samples.")

    def retrieve(self, query: str) -> list[Any]:
        """
        Retrieves the k most similar context chunks for a given query.

        :param query: A string containing the query.
        :param k: An integer representing the number of similar context chunks to retrieve.
        :return: A string containing the retrieved context chunks.
        """
        query_embedding = np.array(
            [
                np.array(
                    self.question_embedding_model.create_embedding(query),
                    dtype=np.float32,
                ).squeeze()
            ]
        )

        context = []

        if self.use_top_k:
            distances, indices = self.index.search(query_embedding, self.top_k)
            print(distances,indices)
            print(distances[0][2],indices)
            for i in range(self.top_k):
                context.append({'text':self.context_chunks[indices[0][i]],'score':distances[0][i]})
        else:
            range_ = int(self.max_context_tokens / self.max_tokens)
            _, indices = self.index.search(query_embedding, range_)
            total_tokens = 0
            for i in range(range_):
                tokens = len(self.tokenizer.encode(self.context_chunks[indices[0][i]]))
                context.append(self.context_chunks[indices[0][i]])
                if total_tokens + tokens > self.max_context_tokens:
                    break
                total_tokens += tokens

        return context


if __name__ == '__main__':
    from transformers import AutoTokenizer

    embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
    embedding_model = SBertEmbeddingModel(embedding_model_path)
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
    retriever_config = FaissRetrieverConfig(
        max_tokens=100,
        max_context_tokens=3500,
        use_top_k=True,
        embedding_model=embedding_model,
        top_k=5,
        tokenizer=tokenizer,
        embedding_model_string="bge-large-zh-v1.5",
        index_path="faiss_index.bin",
        rebuild_index=True
    )

    faiss_retriever=FaissRetriever(config=retriever_config)

    documents=[]
    with open('/home/test/codes/GoMate/data/zh_refine.json','r',encoding="utf-8") as f:
        for line in f.readlines():
            data=json.loads(line)
            documents.extend(data['positive'])
            documents.extend(data['negative'])
    print(len(documents))
    faiss_retriever.build_from_texts(documents[:200])

    contexts=faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
    print(contexts)

参考资料

  • 17
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
完整版:https://download.csdn.net/download/qq_27595745/89522468 【课程大纲】 1-1 什么是java 1-2 认识java语言 1-3 java平台的体系结构 1-4 java SE环境安装和配置 2-1 java程序简介 2-2 计算机中的程序 2-3 java程序 2-4 java类库组织结构和文档 2-5 java虚拟机简介 2-6 java的垃圾回收器 2-7 java上机练习 3-1 java语言基础入门 3-2 数据的分类 3-3 标识符、关键字和常量 3-4 运算符 3-5 表达式 3-6 顺序结构和选择结构 3-7 循环语句 3-8 跳转语句 3-9 MyEclipse工具介绍 3-10 java基础知识章节练习 4-1 一维数组 4-2 数组应用 4-3 多维数组 4-4 排序算法 4-5 增强for循环 4-6 数组和排序算法章节练习 5-0 抽象和封装 5-1 面向过程的设计思想 5-2 面向对象的设计思想 5-3 抽象 5-4 封装 5-5 属性 5-6 方法的定义 5-7 this关键字 5-8 javaBean 5-9 包 package 5-10 抽象和封装章节练习 6-0 继承和多态 6-1 继承 6-2 object类 6-3 多态 6-4 访问修饰符 6-5 static修饰符 6-6 final修饰符 6-7 abstract修饰符 6-8 接口 6-9 继承和多态 章节练习 7-1 面向对象的分析与设计简介 7-2 对象模型建立 7-3 类之间的关系 7-4 软件的可维护与复用设计原则 7-5 面向对象的设计与分析 章节练习 8-1 内部类与包装器 8-2 对象包装器 8-3 装箱和拆箱 8-4 练习题 9-1 常用类介绍 9-2 StringBuffer和String Builder类 9-3 Rintime类的使用 9-4 日期类简介 9-5 java程序国际化的实现 9-6 Random类和Math类 9-7 枚举 9-8 练习题 10-1 java异常处理 10-2 认识异常 10-3 使用try和catch捕获异常 10-4 使用throw和throws引发异常 10-5 finally关键字 10-6 getMessage和printStackTrace方法 10-7 异常分类 10-8 自定义异常类 10-9 练习题 11-1 Java集合框架和泛型机制 11-2 Collection接口 11-3 Set接口实现类 11-4 List接口实现类 11-5 Map接口 11-6 Collections类 11-7 泛型概述 11-8 练习题 12-1 多线程 12-2 线程的生命周期 12-3 线程的调度和优先级 12-4 线程的同步 12-5 集合类的同步问题 12-6 用Timer类调度任务 12-7 练习题 13-1 Java IO 13-2 Java IO原理 13-3 流类的结构 13-4 文件流 13-5 缓冲流 13-6 转换流 13-7 数据流 13-8 打印流 13-9 对象流 13-10 随机存取文件流 13-11 zip文件流 13-12 练习题 14-1 图形用户界面设计 14-2 事件处理机制 14-3 AWT常用组件 14-4 swing简介 14-5 可视化开发swing组件 14-6 声音的播放和处理 14-7 2D图形的绘制 14-8 练习题 15-1 反射 15-2 使用Java反射机制 15-3 反射与动态代理 15-4 练习题 16-1 Java标注 16-2 JDK内置的基本标注类型 16-3 自定义标注类型 16-4 对标注进行标注 16-5 利用反射获取标注信息 16-6 练习题 17-1 顶目实战1-单机版五子棋游戏 17-2 总体设计 17-3 代码实现 17-4 程序的运行与发布 17-5 手动生成可执行JAR文件 17-6 练习题 18-1 Java数据库编程 18-2 JDBC类和接口 18-3 JDBC操作SQL 18-4 JDBC基本示例 18-5 JDBC应用示例 18-6 练习题 19-1 。。。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值