【Faiss】源码阅读(三)——IVFFlat(倒序索引)

摘要: 这里主要讲整个实现过程与核心思路。

1. 核心思路

前面讲的IndexFlatL2的索引方式,主要就是一种暴力搜索的方式,只是在计算的过程中针对不同的平台进行了指令集优化。

这里的IndexIVFFlat索引主要

  • 对原始m个样本随机下采样 n×256 个样本,n:表示聚类中心点个数
  • 对下采样的样本,采用kmean进行聚类
  • 对原始m个底库样本,根据聚类中心进行分桶
  • 对要查询的query,针对聚类中心进行分桶,然后采用暴力搜索的方式。

2. 测试

code

/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdio>
#include <cstdlib>
#include <cassert>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/utils/utils.h>


int main() {
    int d = 64;                            // dimension
    int nb = 100000;                       // database size
    int nq = 10000;                        // nb of queries

    float *xb = new float[d * nb];
    float *xq = new float[d * nq];

    for(int i = 0; i < nb; i++) {
        for(int j = 0; j < d; j++)
            xb[d * i + j] = drand48();
        xb[d * i] += i / 1000.;
    }

    for(int i = 0; i < nq; i++) {
        for(int j = 0; j < d; j++)
            xq[d * i + j] = drand48();
        xq[d * i] += i / 1000.;
    }


    int nlist = 100;
    int k = 4;

    faiss::IndexFlatL2 quantizer(d);       // the other index
    faiss::IndexIVFFlat index(&quantizer, d, nlist, faiss::METRIC_L2);
    // here we specify METRIC_L2, by default it performs inner-product search
    double t0 = faiss::getmillisecs();
    index.verbose = 1;
    assert(!index.is_trained);
    index.train(nb, xb);
    double t1 = faiss::getmillisecs();
    printf("train time:%.3f \n", (t1-t0)/1000.0);

    assert(index.is_trained);
    index.add(nb, xb);                    // 对底库根据聚类的中心点分桶装
    double t2 = faiss::getmillisecs();
    printf("add time:%.3f \n", (t2-t1)/1000.0);

    {       // search xq
        long *I = new long[k * nq];
        float *D = new float[k * nq];

        index.search(nq, xq, k, D, I);
        double t3 = faiss::getmillisecs();
        printf("search1 time:%.3f \n", (t3-t2)/1000.0);

        printf("I=\n");
        for(int i = nq - 5; i < nq; i++) {
            for(int j = 0; j < k; j++)
                printf("%5ld ", I[i * k + j]);
            printf("\n");
        }

        index.nprobe = 10;
        index.search(nq, xq, k, D, I);
        double t4 = faiss::getmillisecs();
        printf("search2 time:%.3f \n", (t4-t3)/1000.0);

        printf("I=\n");
        for(int i = nq - 5; i < nq; i++) {
            for(int j = 0; j < k; j++)
                printf("%5ld ", I[i * k + j]);
            printf("\n");
        }

        delete [] I;
        delete [] D;
    }



    delete [] xb;
    delete [] xq;

    return 0;
}

Training level-1 quantizer
Training level-1 quantizer on 100000 vectors in 64D
Training IVF residual
IndexIVF: no residual training
train time:0.190
IndexIVFFlat::add_core: added 100000 / 100000 vectors
add time:0.074
search1 time:0.044
I=
10827 10004 10049 10147
10267 10880 10330 10156
9896 10093 10361 10184
8603 9895 9946 9335
10123 11099 10876 9647
search2 time:0.202
I=
10842 10827 9938 10004
9403 10267 10880 10330
9896 10146 10093 10361
8603 10523 10582 9895
11460 10123 11099 10876

nprobe改变之后对首位搜索结果有影响。查找聚类中心的个数,默认为1个,若nprobe=nlist则等同于精确查找.
对nprobe×k个搜索结果进行重排序,找出距离最小的k个。为什么会有nprobe×k个搜索结果?因为我们不能完全信任level1的搜索结果,level1的最近邻聚类中心对应的key中并不一定包含level2的最近邻,为了保险期间,我们扩大对level1的信任范围,取最近的nprobe个聚类中心,在它们对应的子数组中分别搜索k近邻,最后再对整个结果进行重排。来源

3. 实现细节

  • 对底库数据做 n×265的随机下采样,用kmeans做聚类训练
void Clustering::train (idx_t nx, const float *x_in, Index & index) {
    FAISS_THROW_IF_NOT_FMT (nx >= k,
             "Number of training points (%ld) should be at least "
             "as large as number of clusters (%ld)", nx, k);

    double t0 = getmillisecs();

    // yes it is the user's responsibility, but it may spare us some
    // hard-to-debug reports.
    for (size_t i = 0; i < nx * d; i++) {
      FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
                        "input contains NaN's or Inf's");       // 输入数值检查
    }

    const float *x = x_in;
    ScopeDeleter<float> del1;

    if (nx > k * max_points_per_centroid) {                     // 默认分支,k=100,max_points_per_centroid=256
        if (verbose)
            printf("Sampling a subset of %ld / %ld for training\n",
                   k * max_points_per_centroid, nx);
        std::vector<int> perm (nx);
        rand_perm (perm.data (), nx, seed);
        nx = k * max_points_per_centroid;               // 100个点,每个点256个样本。总样本数
        float * x_new = new float [nx * d];
        for (idx_t i = 0; i < nx; i++)
            memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);      // 随机下采样
        x = x_new;
        del1.set (x);
    } else if (nx < k * min_points_per_centroid) {
        fprintf (stderr,
                 "WARNING clustering %ld points to %ld centroids: "
                 "please provide at least %ld training points\n",
                 nx, k, idx_t(k) * min_points_per_centroid);
    }


    if (nx == k) {
        if (verbose) {
            printf("Number of training points (%ld) same as number of "
                   "clusters, just copying\n", nx);
        }
        // this is a corner case, just copy training set to clusters
        centroids.resize (d * k);
        memcpy (centroids.data(), x_in, sizeof (*x_in) * d * k);
        index.reset();
        index.add(k, x_in);
        return;
    }


    if (verbose)
        printf("Clustering %d points in %ldD to %ld clusters, "
               "redo %d times, %d iterations\n",
               int(nx), d, k, nredo, niter);

    idx_t * assign = new idx_t[nx];
    ScopeDeleter<idx_t> del (assign);
    float * dis = new float[nx];
    ScopeDeleter<float> del2(dis);

    // for redo
    float best_err = HUGE_VALF;
    std::vector<float> best_obj;
    std::vector<float> best_centroids;

    // support input centroids

    FAISS_THROW_IF_NOT_MSG (
       centroids.size() % d == 0,
       "size of provided input centroids not a multiple of dimension");

    size_t n_input_centroids = centroids.size() / d;        // n_input_centroids=0,输入的中心点数

    if (verbose && n_input_centroids > 0) {
        printf ("  Using %zd centroids provided as input (%sfrozen)\n",
                n_input_centroids, frozen_centroids ? "" : "not ");
    }

    double t_search_tot = 0;
    if (verbose) {
        printf("  Preprocessing in %.2f s\n",
               (getmillisecs() - t0) / 1000.);
    }
    t0 = getmillisecs();

    for (int redo = 0; redo < nredo; redo++) {              // nredo=1

        if (verbose && nredo > 1) {
            printf("Outer iteration %d / %d\n", redo, nredo);
        }

        // initialize remaining centroids with random points from the dataset
        centroids.resize (d * k);                           // 中心点的存储空间
        std::vector<int> perm (nx);                 // 中心聚类的总样本数

        rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
        for (int i = n_input_centroids; i < k ; i++)        // 随机初始化聚类中心
            memcpy (&centroids[i * d], x + perm[i] * d,
                    d * sizeof (float));

        post_process_centroids ();

        if (index.ntotal != 0) {
            index.reset();
        }

        if (!index.is_trained) {
            index.train (k, centroids.data());          // 没有训练
        }

        index.add (k, centroids.data());                // 中心点
        float err = 0;
        for (int i = 0; i < niter; i++) {               // k-mean循环
            double t0s = getmillisecs();
            index.search (nx, x, 1, dis, assign);       // 计算聚类样本和中心点的距离,每个聚类样本很某个中心点的最小距离/索引
            InterruptCallback::check();
            t_search_tot += getmillisecs() - t0s;       // 时间

            err = 0;
            for (int j = 0; j < nx; j++)                // 距离求和
                err += dis[j];
            obj.push_back (err);

            int nsplit = km_update_centroids (          // 更新中心点
                  x, centroids.data(),
                  assign, d, k, nx, frozen_centroids ? n_input_centroids : 0);

            if (verbose) {
                printf ("  Iteration %d (%.2f s, search %.2f s): "
                        "objective=%g imbalance=%.3f nsplit=%d       \r",
                        i, (getmillisecs() - t0) / 1000.0,
                        t_search_tot / 1000,
                        err, imbalance_factor (nx, k, assign),
                        nsplit);
                fflush (stdout);
            }

            post_process_centroids ();

            index.reset ();
            if (update_index)       // update_index=false
                index.train (k, centroids.data());

            assert (index.ntotal == 0);
            index.add (k, centroids.data());            // 将聚类中心点放入quantizer的底库
            InterruptCallback::check ();
        }
        if (verbose) printf("\n");
        if (nredo > 1) {
            if (err < best_err) {
                if (verbose)
                    printf ("Objective improved: keep new clusters\n");
                best_centroids = centroids;
                best_obj = obj;
                best_err = err;
            }
            index.reset ();
        }
    }
    if (nredo > 1) {
        centroids = best_centroids;
        obj = best_obj;
        index.reset();
        index.add(k, best_centroids.data());
    }

}
  • 将m个底库样本根据聚类样本分桶
FAISS_THROW_IF_NOT (is_trained);
    assert (invlists);
    FAISS_THROW_IF_NOT_MSG (!(maintain_direct_map && xids),
                            "cannot have direct map and add with ids");
    const int64_t * idx;
    ScopeDeleter<int64_t> del;

    if (precomputed_idx) {
        idx = precomputed_idx;
    } else {
        int64_t * idx0 = new int64_t [n];
        del.set (idx0);
        quantizer->assign (n, x, idx0);         // 计算query和聚类中心的匹配关系
        idx = idx0;
    }
    int64_t n_add = 0;
    for (size_t i = 0; i < n; i++) {
        int64_t id = xids ? xids[i] : ntotal + i;
        int64_t list_no = idx [i];              // 匹配的聚类中心的索引

        if (list_no < 0)
            continue;
        const float *xi = x + i * d;
        size_t offset = invlists->add_entry (
              list_no, id, (const uint8_t*) xi);    // 将样本加到聚类中心

        if (maintain_direct_map)
            direct_map.push_back (list_no << 32 | offset);
        n_add++;
    }
    if (verbose) {
        printf("IndexIVFFlat::add_core: added %ld / %ld vectors\n",
               n_add, n);
    }
    ntotal += n;
  • 查询。首先计算待查询样本query和聚类中心的匹配。然后再某聚类中心桶中进行暴力搜索
// 聚类中心
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);            // nprobe=1
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

double t0 = getmillisecs();
quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());  // 计算与聚类中心的距离/匹配关系
indexIVF_stats.quantization_time += getmillisecs() - t0;

t0 = getmillisecs();
invlists->prefetch_lists (idx.get(), n * nprobe);           // 未做任何操作

search_preassigned (n, x, k, idx.get(), coarse_dis.get(),       // 在分桶中进行暴力搜索
                  distances, labels, false);
indexIVF_stats.search_time += getmillisecs() - t0;
// 分桶中暴力搜索
    long nprobe = params ? params->nprobe : this->nprobe;
    long max_codes = params ? params->max_codes : this->max_codes;

    size_t nlistv = 0, ndis = 0, nheap = 0;

    using HeapForIP = CMin<float, idx_t>;
    using HeapForL2 = CMax<float, idx_t>;

    bool interrupt = false;

    // don't start parallel section if single query
    bool do_parallel =
        parallel_mode == 0 ? n > 1 :
        parallel_mode == 1 ? nprobe > 1 :
        nprobe * n > 1;

#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
     {
        InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);        // 获得倒序索引
        ScopeDeleter1<InvertedListScanner> del(scanner);

        /*****************************************************
         * Depending on parallel_mode, there are two possible ways
         * to organize the search. Here we define local functions
         * that are in common between the two
         ******************************************************/

        // intialize + reorder a result heap

        auto init_result = [&](float *simi, idx_t *idxi) {              // 定义一个匿名函数,参数按引用传递
            if (metric_type == METRIC_INNER_PRODUCT) {                  // 用于simi,idxi的初始化
                heap_heapify<HeapForIP> (k, simi, idxi);
            } else {
                heap_heapify<HeapForL2> (k, simi, idxi);
            }
        };

        auto reorder_result = [&] (float *simi, idx_t *idxi) {          // simi,idxi排序用
            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_reorder<HeapForIP> (k, simi, idxi);
            } else {
                heap_reorder<HeapForL2> (k, simi, idxi);
            }
        };

        // single list scan using the current scanner (with query
        // set porperly) and storing results in simi and idxi
        auto scan_one_list = [&] (idx_t key, float coarse_dis_i,        // 
                                  float *simi, idx_t *idxi) {

            if (key < 0) {
                // not enough centroids for multiprobe
                return (size_t)0;
            }
            FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
                                    "Invalid key=%ld nlist=%ld\n",      // key聚类中心点的索引
                                    key, nlist);

            size_t list_size = invlists->list_size(key);                // 聚类中心点的样本数

            // don't waste time on empty lists
            if (list_size == 0) {
                return (size_t)0;
            }

            scanner->set_list (key, coarse_dis_i);

            nlistv++;

            InvertedLists::ScopedCodes scodes (invlists, key);          // 聚类中心样本的数值

            std::unique_ptr<InvertedLists::ScopedIds> sids;
            const Index::idx_t * ids = nullptr;

            if (!store_pairs)  {
                sids.reset (new InvertedLists::ScopedIds (invlists, key));      // 聚类中心样本的索引
                ids = sids->get();
            }

            nheap += scanner->scan_codes (list_size, scodes.get(),
                                          ids, simi, idxi, k);          // simi,idxi用于存放和query匹配的样本的距离和索引

            return list_size;
        };

        /****************************************************
         * Actual loops, depending on parallel_mode
         ****************************************************/

        if (parallel_mode == 0) {

#pragma omp for
            for (size_t i = 0; i < n; i++) {

                if (interrupt) {
                    continue;
                }

                // loop over queries
                scanner->set_query (x + i * d);         // 写入query
                float * simi = distances + i * k;
                idx_t * idxi = labels + i * k;

                init_result (simi, idxi);

                long nscan = 0;

                // loop over probes
                for (size_t ik = 0; ik < nprobe; ik++) {

                    nscan += scan_one_list (                // 单样本的查询
                         keys [i * nprobe + ik],
                         coarse_dis[i * nprobe + ik],
                         simi, idxi
                    );

                    if (max_codes && nscan >= max_codes) {
                        break;
                    }
                }

                ndis += nscan;
                reorder_result (simi, idxi);            // 对simi,idxi排序

                if (InterruptCallback::is_interrupted ()) {
                    interrupt = true;
                }

            } // parallel for

3. 其他

  1. 相关注释的代码
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值