faiss IndexIVFPQ 源码详解 - train

train整体流程简介

    1. 目标:生成原向量中心点,残差(向量中心点的差值)向量中心点,部分预计算的距离(因为采用的是pq方式,所以原始向量会分成M个子空间,这里训练的都是子空间的中心点)。
    2. 残差:向量中心点的差值
    3. 为什么训练残差向量中心点:因为残差向量更集中,误差更小。
    4. PQ: 把向量空间划分为M个子空间,在检索时从每个子空间找最邻近的中心点,假如每个子空间有n个中心点,则pq可表达的中点点为n的m次方个。
    5. 流程
      1. 把原始向量分成M个子空间,针对每个子空间训练中心点(如果每个子空间的中心点为n,则pq可表达n的M次方个中心点)。
      2. 查找向量对应的中心点
      3. 向量减去对应的中心点生成残差向量
      4. 针对残差向量生成二级量化器。

 

距离计算方式

339 /** Precomputed tables for residuals

 340  *

 341  * During IVFPQ search with by_residual, we compute

 342  *

 343  *     d = || x - y_C - y_R ||^2 , x - y_c:表示的残差,y_r:表示残差对应的中心点

 344  *

 345  * where x is the query vector, y_C the coarse centroid, y_R the

 346  * refined PQ centroid. The expression can be decomposed as:

 347  *

 348  *    d = || x - y_C ||^2 + || y_R ||^2 + 2 * (y_C|y_R) - 2 * (x|y_R)

 349  *        ---------------   ---------------------------       -------

 350  *             term 1                 term 2                   term 3

 351  *

 352  * When using multiprobe, we use the following decomposition:

 353  * - term 1 is the distance to the coarse centroid, that is computed

 354  *   during the 1st stage search.

 355  * - term 2 can be precomputed, as it does not involve x. However,

 356  *   because of the PQ, it needs nlist * M * ksub storage. This is why

 357  *   use_precomputed_table is off by default

 358  * - term 3 is the classical non-residual distance table.

 359  *

 360  * Since y_R defined by a product quantizer, it is split across

 361  * subvectors and stored separately for each subvector. If the coarse

 362  * quantizer is a MultiIndexQuantizer then the table can be stored

 363  * more compactly.

 

 

 

69 /*n:向量个数, x:整个向量*/

  70 void IndexIVFPQ::train_residual_o (idx_t n, const float *x, float *residuals_2)

  71 {

  72     const float * x_in = x;

  73     /*d:全空间维度*/

  74     x = fvecs_maybe_subsample (

  75          d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub,

  76          x, verbose, pq.cp.seed); ///抽样(出入向量个数 > 最大向量数),抽取pq.cp.max_points_per_centroid * pq.ksub个

  77

  78     ScopeDeleter<float> del_x (x_in == x ? nullptr : x);

  79

  80     const float *trainset;

  81     ScopeDeleter<float> del_residuals;

  82     if (by_residual) { ///猜测是编码

  83         if(verbose) printf("computing residuals\n");

  84         idx_t * assign = new idx_t [n]; // assignement to coarse centroids

  85         ScopeDeleter<idx_t> del (assign);

  86         quantizer->assign (n, x, assign); ///quantizer:一级量化器,获得assign,                                                            

     assign是编码,每个检索向量都对应topk的中心点,这里就是中心点对应的编码(默认k=1)

  87         float *residuals = new float [n * d];

  88         del_residuals.set (residuals);

  89         for (idx_t i = 0; i < n; i++)

  90            quantizer->compute_residual (x + i * d, residuals+i*d, assign[i]); ///向量和对应中心点向量的差(每个向量在各个分量上的差)

  91

  92         trainset = residuals;

  93     } else {

  94         trainset = x;

  95     }

  96     if (verbose)

  97         printf ("training %zdx%zd product quantizer on %ld vectors in %dD\n",

  98                 pq.M, pq.ksub, n, d);

  99     pq.verbose = verbose; ///pq二级量化器

 100     pq.train (n, trainset); ///向量和中心点向量的差值(为了减少距离误差,因为残差更收敛)

 101

 102     if (do_polysemous_training) {

 103         if (verbose)

 104             printf("doing polysemous training for PQ\n");

 105         PolysemousTraining default_pt;

 106         PolysemousTraining *pt = polysemous_training;

 107         if (!pt) pt = &default_pt;

 108         pt->optimize_pq_for_hamming (pq, n, trainset); ///hamming:字符距离

 109     }

 110

 111     // prepare second-level residuals for refine PQ

 112     if (residuals_2) { ///二级残差

 113         uint8_t *train_codes = new uint8_t [pq.code_size * n];

 114         ScopeDeleter<uint8_t> del (train_codes);

 115         pq.compute_codes (trainset, train_codes, n); ///

116

 117         for (idx_t i = 0; i < n; i++) {

 118             const float *xx = trainset + i * d;

 119             float * res = residuals_2 + i * d;

 120             pq.decode (train_codes + i * pq.code_size, res);

 121             for (int j = 0; j < d; j++)

 122                 res[j] = xx[j] - res[j];

 123         }

 124

 125     }

 126

 127     if (by_residual) {

 128         precompute_table (); ///为减少误差,会把向量映射到残差空间,求和残差中心点的距离,这里计算一些和搜索相关无关的值

 129     }

 130

 131 }

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值