【Faiss】基础索引类型(六)

基础索引类型

数据准备

import numpy as np 
d = 512          #维数
n_data = 2000   
np.random.seed(0) 
data = []
mu = 3
sigma = 0.1
for i in range(n_data):
    data.append(np.random.normal(mu, sigma, d))
data = np.array(data).astype('float32')

#query
query = []
n_query = 10
np.random.seed(12) 
query = []
for i in range(n_query):
    query.append(np.random.normal(mu, sigma, d))
query = np.array(query).astype('float32')

#导入faiss
import sys
sys.path.append('/home/maliqi/faiss/python/')
import faiss

1.精确搜索(Exact Search for L2)

一种暴力搜索方法,遍历数据库中的每一个向量与查询向量对比。

index = faiss.IndexFlatL2(d)
# index = faiss.index_factory(d, "Flat") #两种定义方式
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[8.61838   8.782156  8.782816  8.832029  8.837633  8.848496  8.897978
  8.916636  8.919006  8.9374   ]
 [9.033303  9.038907  9.091705  9.15584   9.164591  9.200112  9.201884
  9.220335  9.279477  9.312859 ]
 [8.063818  8.211029  8.306456  8.373352  8.459253  8.459892  8.498557
  8.546464  8.555408  8.621426 ]
 [8.193894  8.211956  8.34701   8.446963  8.45299   8.45486   8.473572
  8.50477   8.513636  8.530684 ]
 [8.369624  8.549444  8.704066  8.736764  8.760082  8.777319  8.831345
  8.835486  8.858271  8.860058 ]
 [8.299072  8.432398  8.434382  8.457374  8.539217  8.562359  8.579033
  8.618736  8.630861  8.643393 ]
 [8.615004  8.615164  8.72604   8.730943  8.762621  8.796932  8.797068
  8.797365  8.813985  8.834726 ]
 [8.377227  8.522776  8.711159  8.724562  8.745737  8.763846  8.768602
  8.7727995 8.786856  8.828224 ]
 [8.342917  8.488056  8.655106  8.662771  8.701336  8.741287  8.743608
  8.770507  8.786264  8.849051 ]
 [8.522164  8.575703  8.68462   8.767247  8.782909  8.850494  8.883733
  8.90369   8.909393  8.91768  ]]

2.精确搜索(Exact Search for Inner Product)

当数据库向量是标准化的,计算返回的distance就是余弦相似度。

index = faiss.IndexFlatIP(d)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[4621.749  4621.5464 4619.745  4619.381  4619.177  4618.0615 4617.169
  4617.0566 4617.0483 4616.631 ]
 [4637.3975 4637.288  4635.368  4635.2446 4634.881  4633.608  4633.0215
  4632.7637 4632.56   4632.373 ]
 [4621.756  4621.4697 4619.7485 4619.5615 4619.424  4618.0186 4616.9927
  4616.962  4616.901  4616.735 ]
 [4623.6074 4623.5596 4621.3965 4621.158  4620.906  4619.838  4618.9756
  4618.9126 4618.7695 4618.478 ]
 [4625.553  4625.0645 4623.461  4623.196  4622.957  4621.337  4620.7373
  4620.717  4620.5635 4620.2485]
 [4628.489  4628.449  4626.491  4626.487  4625.6406 4624.6143 4624.29
  4624.     4623.7524 4623.618 ]
 [4637.7466 4637.338  4635.3047 4635.125  4634.748  4633.0137 4632.864
  4632.58   4632.3027 4632.2324]
 [4630.472  4630.333  4628.264  4627.9375 4627.738  4626.8965 4625.814
  4625.7227 4625.4443 4625.091 ]
 [4635.7715 4635.489  4633.6904 4633.568  4632.658  4631.463  4631.4307
  4631.101  4630.99   4630.3066]
 [4625.6753 4625.558  4623.454  4623.3926 4623.324  4622.2827 4621.7783
  4621.1157 4620.905  4620.854 ]]

3.(Hierarchical Navigable Small World graph exploration)

返回近似结果。

index = faiss.IndexHNSWFlat(d,16)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[8.61838   8.832029  8.848496  8.897978  8.916636  8.9374    8.9597
  8.962785  8.984709  8.998907 ]
 [9.038907  9.164591  9.200112  9.201884  9.220335  9.312859  9.34434
  9.344851  9.416974  9.421429 ]
 [8.306456  8.373352  8.459253  8.546464  8.631898  8.63715   8.63917
  8.713682  8.735945  8.7704735]
 [8.193894  8.211956  8.34701   8.45486   8.473572  8.50477   8.513636
  8.530684  8.545482  8.617173 ]
 [8.369624  8.760082  8.831345  8.858271  8.860058  8.862642  8.936951
  8.996922  8.998444  9.022133 ]
 [8.299072  8.432398  8.434382  8.539217  8.562359  8.698317  8.753672
  8.768751  8.779131  8.780444 ]
 [8.615004  8.615164  8.730943  8.797365  8.861536  8.885755  8.911812
  8.922768  8.942963  8.980488 ]
 [8.377227  8.522776  8.711159  8.724562  8.745737  8.768602  8.7727995
  8.786856  8.828224  8.879469 ]
 [8.342917  8.488056  8.662771  8.741287  8.743608  8.770507  8.857255
  8.893716  8.932134  8.933593 ]
 [8.575703  8.68462   8.850494  8.883733  8.90369   8.909393  8.91768
  8.936615  8.961668  8.977329 ]]

4.倒排表搜索(Inverted file with exact post-verification)

快速入门部分介绍过。

nlist = 50
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)
index.train(data)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[8.837633  9.122337  9.217627  9.362019  9.39345   9.396795  9.401556
  9.446939  9.52043   9.5279255]
 [9.436286  9.636714  9.707813  9.714355  9.734249  9.809814  9.87722
  9.960412  9.978079  9.982276 ]
 [8.621426  8.658703  8.842339  8.862192  8.891519  8.937078  8.972767
  8.98658   9.007745  9.088661 ]
 [8.211956  8.735372  8.747662  8.800873  8.917062  9.1208725 9.178852
  9.215968  9.2192    9.265095 ]
 [8.858271  8.998444  9.041813  9.0883045 9.159481  9.169218  9.187948
  9.203735  9.204121  9.256811 ]
 [8.434382  8.539217  8.630861  8.753672  8.768751  8.794859  8.815165
  8.817884  8.8404    8.848925 ]
 [8.861536  8.878873  8.942963  8.944212  8.9446945 8.95914   8.980488
  9.051479  9.059914  9.081419 ]
 [9.15522   9.423113  9.432117  9.465836  9.529045  9.554071  9.556268
  9.638275  9.656209  9.69151  ]
 [8.743608  8.902418  9.065649  9.201052  9.223066  9.223073  9.247414
  9.269661  9.288244  9.291237 ]
 [8.936615  9.077     9.152468  9.1537075 9.313195  9.314999  9.373196
  9.400535  9.434517  9.445862 ]]

5.LSH(Locality-Sensitive Hashing (binary flat index))

nbits = 2 * d
index = faiss.IndexLSH(d, nbits)
index.train(data)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[ 8. 10. 10. 10. 10. 10. 10. 11. 11. 11.]
 [ 7.  8.  9.  9.  9. 10. 10. 10. 10. 10.]
 [ 7.  8.  8.  9.  9.  9.  9.  9.  9.  9.]
 [ 9.  9. 10. 11. 12. 12. 12. 12. 12. 12.]
 [ 6.  6.  6.  7.  7.  8.  8.  8.  8.  8.]
 [ 8.  8.  8.  9.  9.  9.  9.  9. 10. 10.]
 [ 6.  7.  8.  8.  9.  9.  9.  9.  9.  9.]
 [ 9.  9.  9.  9.  9.  9.  9.  9.  9. 10.]
 [ 7.  8.  8.  8.  8.  8.  8.  9.  9.  9.]
 [ 9.  9.  9. 10. 10. 10. 10. 10. 10. 10.]]

6.SQ量化(Scalar quantizer (SQ) in flat mode)

index = faiss.IndexScalarQuantizer(d, 4)
index.train(data)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[8.623227  8.777792  8.785317  8.828824  8.83549   8.845292  8.896896
  8.914818  8.922382  8.934983 ]
 [9.028506  9.037546  9.099248  9.1526165 9.16542   9.19639   9.200499
  9.224975  9.274046  9.3053875]
 [8.064029  8.21301   8.310526  8.376435  8.457833  8.462002  8.501087
  8.550647  8.556992  8.624525 ]
 [8.19665   8.210531  8.346436  8.444769  8.452809  8.454114  8.4745245
  8.496618  8.510042  8.525612 ]
 [8.370452  8.547959  8.704323  8.733619  8.763926  8.776738  8.829511
  8.835644  8.857149  8.859046 ]
 [8.29591   8.432422  8.435944  8.454732  8.542395  8.565367  8.579683
  8.621871  8.632034  8.644775 ]
 [8.609016  8.612934  8.72663   8.734133  8.758857  8.797326  8.797966
  8.798654  8.815295  8.8382225]
 [8.378947  8.521084  8.711153  8.726161  8.748383  8.759655  8.768218
  8.769182  8.792372  8.834644 ]
 [8.340463  8.48951   8.659344  8.664954  8.702756  8.741513  8.741941
  8.768993  8.781276  8.852154 ]
 [8.520282  8.574987  8.683459  8.769213  8.7820425 8.85128   8.881118
  8.906741  8.907756  8.924014 ]]

7.PQ量化(Product quantizer (PQ) in flat mode)

M = 8 #必须是d的因数
nbits = 6  #只能是8, 12, 16
index = faiss.IndexPQ(d, M, nbits)
index.train(data)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[5.3184814 5.33667   5.3638916 5.366333  5.3704834 5.4000244 5.404663
  5.415283  5.425659  5.427246 ]
 [5.6835938 5.686035  5.687134  5.7489014 5.76062   5.7731934 5.7766113
  5.7875977 5.798828  5.7991943]
 [4.902588  5.0057373 5.0323486 5.036255  5.045044  5.048828  5.0498047
  5.0499268 5.072998  5.0737305]
 [4.844116  4.850586  4.868042  4.8946533 4.8997803 4.8999023 4.902954
  4.909546  4.9210205 4.921875 ]
 [5.279419  5.333252  5.3344727 5.3431396 5.35083   5.357422  5.366211
  5.3862305 5.38855   5.3936768]
 [5.019409  5.048706  5.0942383 5.1052246 5.116455  5.157593  5.159424
  5.168457  5.171875  5.194092 ]
 [5.0563965 5.0909424 5.1367188 5.1534424 5.1724854 5.199951  5.20105
  5.2144775 5.214966  5.23938  ]
 [5.16333   5.173706  5.2418213 5.265259  5.265869  5.274414  5.291382
  5.307495  5.309204  5.310425 ]
 [5.1501465 5.2508545 5.291992  5.3186035 5.3205566 5.328369  5.336548
  5.3479004 5.35376   5.360962 ]
 [5.2751465 5.2772217 5.279663  5.3304443 5.350708  5.3571777 5.3669434
  5.373047  5.373413  5.382324 ]]

8.倒排表乘积量化(IVFADC (coarse quantizer+PQ on residuals))

M = 8
nbits = 4
nlist = 50
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)
index.train(data)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
[[5.1985765 5.209732  5.233874  5.237282  5.2553835 5.262968  5.270462
  5.2895284 5.2908745 5.302353 ]
 [5.5696826 5.5942397 5.611737  5.6186624 5.619787  5.643144  5.646076
  5.676093  5.682111  5.6982036]
 [4.7446747 4.824335  4.834736  4.844829  4.850663  4.853364  4.867393
  4.873641  4.8785725 4.88787  ]
 [4.783175  4.797909  4.8491716 4.85687   4.857151  4.8586845 4.860058
  4.866444  4.868099  4.885188 ]
 [5.1260395 5.134188  5.1386065 5.141901  5.1756086 5.192538  5.1938267
  5.1975694 5.199704  5.2012296]
 [4.882325  4.900981  4.9040375 4.911916  4.916094  4.923492  4.928433
  4.928472  4.937878  4.95728  ]
 [4.9729834 4.976016  4.984484  5.0074816 5.0200887 5.0217285 5.029479
  5.029899  5.0346465 5.0349855]
 [5.1357193 5.147153  5.1525207 5.189519  5.217377  5.220489  5.2341766
  5.239973  5.2411985 5.253551 ]
 [5.0623484 5.087064  5.1075807 5.109309  5.110051  5.1330123 5.1387715
  5.1431603 5.151037  5.1516275]
 [5.12455   5.163775  5.1762547 5.185327  5.190364  5.19723   5.2099175
  5.2115583 5.214532  5.2182474]]

cell-probe方法

为了加速索引过程,经常采用划分子类空间(如k-means)的方法,虽然这样无法保证最后返回的结果是完全正确的。先划分子类空间,再在部分子空间中搜索的方法,就是cell-probe方法。
具体流程为:
1)数据集空间被划分为n个部分,在k-means中,表现为n个类;
2)每个类中的向量保存在一个倒排表中,共有n个倒排表;
3)查询时,选中nprobe个倒排表;
4)将这几个倒排表中的向量与查询向量作对比。
在这种方法中,只需要排查数据库中的一部分向量,大约只有nprobe/n的数据,因为每个倒排表的长度并不一致(每个类中的向量个数不一定相等)。

cell-probe粗量化

在一些索引类型中,需要一个Flat index作为粗量化器,如IndexIVFFlat,在训练的时候会将类中心保存在Flat index中,在add和search阶段,会首先判定将其落入哪个类空间。在search阶段,nprobe参数需要调整以权衡检索精度与检索速度。
实验表明,对高维数据,需要维持比较高的nprobe数值才能保证精度。

与LSH的优劣

LSH也是一种cell-probe方法,与其相比,LSH有一下一点不足:
1)LSH需要大量的哈希方程,会带来额外的内存开销;
2)哈希函数不适合输入数据。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值