faiss原理(Product Quantization)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

import numpy as np
from scipy.cluster.vq import vq, kmeans2
from scipy.spatial.distance import cdist


def train(vec, M, Ks):
    """
    :param vec: 向量
    :param M: 子向量组数
    :param Ks: 每组向量聚类个数
    :return: codeword: [M, Ks, Ds],
        codeword[m][k]表示第m组子向量第k个子向量所属的聚类中心向量
    """
    Ds = int(vec.shape[1] / M)
    codeword = np.empty((M, Ks, Ds), np.float32)

    for m in range(M):
        vec_sub = vec[:, m * Ds: (m + 1) * Ds]
        # 第m组子向量vec_sub聚成Ks类
        # kmeans2返回两个结果,第一个是原始向量归属类目的中心向量,第二个是类目ID
        codeword[m], label = kmeans2(vec_sub, Ks)
    return codeword


def encode(codeword, vec):
    """
    :param codeword: 码本,shape为[M, Ks, Ds]
    :param vec: 原始向量
    :return: pqcode: pq编码结果,
        shape为[N, M],每个原始向量用M组子向量的聚类中心ID表示
    """
    M, Ks, Ds = codeword.shape
    # pq编码shape为[N, M]
    pqcode = np.empty((vec.shape[0], M), np.int64)
    for m in range(M):
        vec_sub = vec[:, m * Ds: (m + 1) * Ds]
        # 第m组子向量
        # 第m组子向量中每个子向量在第m个码本中查找距离最近的
        pqcode[:, m], dist = vq(vec_sub, codeword[m])
    return pqcode


def search(codeword, pqcode, query):
    """
    :param codeword:
    :param pqcode: pq编码结果, shape为[N, M],每个原始向量用M组子向量的聚类中心ID表示
    :param query: 查询向量[1, d]
    :return: dist:查询向量与原始向量的距离,shape为[N,]
    """
    M, Ks, Ds = codeword.shape
    # 距离向量表, [M, Ks]
    dist_table = np.empty((M, Ks))
    for m in range(M):
        query_sub = query[m * Ds: (m + 1) * Ds]
        # query_sub向量与第m个码本每个向量距离
        dist_table[m, :] = cdist([query_sub], codeword[m], 'sqeuclidean')[0]

    # dist_table[range(M), pqcode] 为 query向量与原始向量在每个子向量的聚类,shape为[N, M]
    # 每组子向量距离相加
    dist = np.sum(dist_table[range(M), pqcode], axis=1)
    return dist


def main():
    # 数据量
    N = 50000
    # 向量维度
    d = 128
    # 每组子向量聚类个数
    Ks = 32
    # 训练向量[N, d]
    vec_train = np.random.random((N, d))
    # 查询向量[1, d]
    # mock 第100个是距离查询向量最近的
    selected_vec = vec_train[100]
    query_vec = selected_vec + [np.random.uniform(-0.001, 0.001) for _ in range(d)]
    query = np.random.random((1, d))
    # 子向量组数
    M = 4

    # 对原始向量划分子向量组,并对每组子向量进行聚类
    codeword = train(vec_train, M, Ks)
    # pq编码
    pqcode = encode(codeword, vec_train)
    # 查询向量
    dist = search(codeword, pqcode, query_vec)

    sorted_dist = sorted(enumerate(dist), key=lambda x: x[1])
    print(sorted_dist[0])
    """
    (100, 6.850794458722508)
    """
# -*- coding:utf-8 -*-


import faiss
import numpy as np


def test_IndexFlatL2(vec_train, query, top_k=5):
    """
    暴力检索
    :param vec_train:
    :param query:
    :return:
    """
    N, d = vec_train.shape
    # 1. 创建索引
    index = faiss.IndexFlatL2(d)
    # 2. 添加数据集
    index.add(vec_train)
    # 3. 检索
    dist_list, label_list = index.search(np.array([query]), k=top_k)
    print(dist_list, label_list)


def test_IndexIVFFlat(vec_train, query_vec, top_k=5):
    """
    通过创建倒排索引优化
    流程:
    使用k-means对train向量进行聚类,查询时query_vec所归属的类目中进行检索
    :param vec_train:
    :param query_vec:
    :param top_k:
    :return:
    """
    nlist = 100  # 聚类中心的个数
    N, d = vec_train.shape
    quantizer = faiss.IndexFlatL2(d)  # the other index
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
    # 添加 训练集
    index.train(vec_train)
    index.add(vec_train)
    # 检索
    res = index.search(query_vec, k=top_k)
    print(res)


def main():
    # 数据量
    N = 50000
    # 向量维度
    d = 128
    vec_train = np.ascontiguousarray(np.random.random((N, d)), np.float32)

    # mock 第100个是距离查询向量最近的
    selected_vec = vec_train[100]
    query_vec = selected_vec + [np.random.uniform(-0.001, 0.001) for _ in range(d)]
    query_vec = np.ascontiguousarray(query_vec, np.float32)
    # 1. 暴力检索,全量检索
    # test_IndexFlatL2(vec_train, query_vec)
    # 2. 倒排索引
    test_IndexIVFFlat(vec_train, query_vec)


if __name__ == '__main__':
    main()

参考
[1] 实例理解product quantization算法
[2] 【关于 Faiss 】 那些的你不知道的事-技术圈
[3] Home · facebookresearch/faiss Wiki
[4] A Survey of Product Quantization
[5] Product quantization for nearest neighbor search
[6] 理解 product quantization 算法
[7] https://github.com/matsui528/na

转子:https://zhuanlan.zhihu.com/p/534004381

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WitsMakeMen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值