基础索引类型
数据准备
import numpy as np
# Build the database: n_data vectors of dimension d whose components are
# drawn from a normal distribution N(mu, sigma), stored as float32.
d = 512  # vector dimensionality
n_data = 2000  # number of database vectors
mu = 3  # mean of the normal distribution
sigma = 0.1  # standard deviation of the normal distribution
np.random.seed(0)  # fixed seed so the generated data is reproducible
# A single (n_data, d) draw fills the array row by row, consuming the random
# stream in the same order as n_data successive draws of d values each.
data = np.random.normal(mu, sigma, (n_data, d)).astype('float32')
# Build the query set: n_query vectors drawn from the same distribution as the
# database (mu, sigma and d come from the data-preparation step), after
# re-seeding the generator with 12.
n_query = 10  # number of query vectors
np.random.seed(12)
# One (n_query, d) draw consumes the random stream in the same order as
# n_query successive draws of d values each.
query = np.random.normal(mu, sigma, (n_query, d)).astype('float32')
# Import faiss
import sys
# NOTE(review): hard-coded, machine-specific path (presumably a local faiss
# build); adjust it, or drop it if faiss is importable without it.
sys.path.append('/home/maliqi/faiss/python/')
import faiss
1.精确搜索(Exact Search for L2)
一种暴力搜索方法,遍历数据库中的每一个向量与查询向量对比。
# Brute-force L2 index over d-dimensional vectors; no train() call is made before add().
index = faiss.IndexFlatL2(d)
# index = faiss.index_factory(d, "Flat") # alternative way to define the index (two equivalent definitions)
index.add(data)
# Ask for 10 results per query vector; dis is the distance matrix that gets printed,
# ind presumably holds the matching database row indices.
dis, ind = index.search(query, 10)
print(dis)
[[8.61838 8.782156 8.782816 8.832029 8.837633 8.848496 8.897978 8.916636 8.919006 8.9374 ] [9.033303 9.038907 9.091705 9.15584 9.164591 9.200112 9.201884 9.220335 9.279477 9.312859 ] [8.063818 8.211029 8.306456 8.373352 8.459253 8.459892 8.498557 8.546464 8.555408 8.621426 ] [8.193894 8.211956 8.34701 8.446963 8.45299 8.45486 8.473572 8.50477 8.513636 8.530684 ] [8.369624 8.549444 8.704066 8.736764 8.760082 8.777319 8.831345 8.835486 8.858271 8.860058 ] [8.299072 8.432398 8.434382 8.457374 8.539217 8.562359 8.579033 8.618736 8.630861 8.643393 ] [8.615004 8.615164 8.72604 8.730943 8.762621 8.796932 8.797068 8.797365 8.813985 8.834726 ] [8.377227 8.522776 8.711159 8.724562 8.745737 8.763846 8.768602 8.7727995 8.786856 8.828224 ] [8.342917 8.488056 8.655106 8.662771 8.701336 8.741287 8.743608 8.770507 8.786264 8.849051 ] [8.522164 8.575703 8.68462 8.767247 8.782909 8.850494 8.883733 8.90369 8.909393 8.91768 ]]
2.精确搜索(Exact Search for Inner Product)
当数据库向量和查询向量都经过L2归一化时,返回的distance就是余弦相似度(下面的示例没有做归一化,因此输出的是原始内积值)。
# Brute-force inner-product index over d-dimensional vectors.
# The vectors are not normalised here, so the printed scores are raw inner
# products rather than cosine similarities.
index = faiss.IndexFlatIP(d)
index.add(data)
# Ask for 10 results per query vector and print the score matrix.
dis, ind = index.search(query, 10)
print(dis)
[[4621.749 4621.5464 4619.745 4619.381 4619.177 4618.0615 4617.169 4617.0566 4617.0483 4616.631 ] [4637.3975 4637.288 4635.368 4635.2446 4634.881 4633.608 4633.0215 4632.7637 4632.56 4632.373 ] [4621.756 4621.4697 4619.7485 4619.5615 4619.424 4618.0186 4616.9927 4616.962 4616.901 4616.735 ] [4623.6074 4623.5596 4621.3965 4621.158 4620.906 4619.838 4618.9756 4618.9126 4618.7695 4618.478 ] [4625.553 4625.0645 4623.461 4623.196 4622.957 4621.337 4620.7373 4620.717 4620.5635 4620.2485] [4628.489 4628.449 4626.491 4626.487 4625.6406 4624.6143 4624.29 4624. 4623.7524 4623.618 ] [4637.7466 4637.338 4635.3047 4635.125 4634.748 4633.0137 4632.864 4632.58 4632.3027 4632.2324] [4630.472 4630.333 4628.264 4627.9375 4627.738 4626.8965 4625.814 4625.7227 4625.4443 4625.091 ] [4635.7715 4635.489 4633.6904 4633.568 4632.658 4631.463 4631.4307 4631.101 4630.99 4630.3066] [4625.6753 4625.558 4623.454 4623.3926 4623.324 4622.2827 4621.7783 4621.1157 4620.905 4620.854 ]]
3.HNSW图搜索(Hierarchical Navigable Small World graph exploration)
返回近似结果。
# HNSW graph index over d-dimensional vectors; results are approximate.
# NOTE(review): 16 is presumably the HNSW parameter M (links per graph node) — confirm.
index = faiss.IndexHNSWFlat(d,16)
index.add(data)
# Ask for 10 results per query vector and print the distance matrix.
dis, ind = index.search(query, 10)
print(dis)
[[8.61838 8.832029 8.848496 8.897978 8.916636 8.9374 8.9597 8.962785 8.984709 8.998907 ] [9.038907 9.164591 9.200112 9.201884 9.220335 9.312859 9.34434 9.344851 9.416974 9.421429 ] [8.306456 8.373352 8.459253 8.546464 8.631898 8.63715 8.63917 8.713682 8.735945 8.7704735] [8.193894 8.211956 8.34701 8.45486 8.473572 8.50477 8.513636 8.530684 8.545482 8.617173 ] [8.369624 8.760082 8.831345 8.858271 8.860058 8.862642 8.936951 8.996922 8.998444 9.022133 ] [8.299072 8.432398 8.434382 8.539217 8.562359 8.698317 8.753672 8.768751 8.779131 8.780444 ] [8.615004 8.615164 8.730943 8.797365 8.861536 8.885755 8.911812 8.922768 8.942963 8.980488 ] [8.377227 8.522776 8.711159 8.724562 8.745737 8.768602 8.7727995 8.786856 8.828224 8.879469 ] [8.342917 8.488056 8.662771 8.741287 8.743608 8.770507 8.857255 8.893716 8.932134 8.933593 ] [8.575703 8.68462 8.850494 8.883733 8.90369 8.909393 8.91768 8.936615 8.961668 8.977329 ]]
4.倒排表搜索(Inverted file with exact post-verification)
快速入门部分介绍过。
# Inverted-file index: the space is split into nlist cells and each cell has
# its own inverted list of vectors.
nlist = 50
# Flat index used as the coarse quantizer (it stores the cell centroids).
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Training is required before add(); nprobe is not set here, so the library default applies.
index.train(data)
index.add(data)
# Ask for 10 results per query vector and print the distance matrix.
dis, ind = index.search(query, 10)
print(dis)
[[8.837633 9.122337 9.217627 9.362019 9.39345 9.396795 9.401556 9.446939 9.52043 9.5279255] [9.436286 9.636714 9.707813 9.714355 9.734249 9.809814 9.87722 9.960412 9.978079 9.982276 ] [8.621426 8.658703 8.842339 8.862192 8.891519 8.937078 8.972767 8.98658 9.007745 9.088661 ] [8.211956 8.735372 8.747662 8.800873 8.917062 9.1208725 9.178852 9.215968 9.2192 9.265095 ] [8.858271 8.998444 9.041813 9.0883045 9.159481 9.169218 9.187948 9.203735 9.204121 9.256811 ] [8.434382 8.539217 8.630861 8.753672 8.768751 8.794859 8.815165 8.817884 8.8404 8.848925 ] [8.861536 8.878873 8.942963 8.944212 8.9446945 8.95914 8.980488 9.051479 9.059914 9.081419 ] [9.15522 9.423113 9.432117 9.465836 9.529045 9.554071 9.556268 9.638275 9.656209 9.69151 ] [8.743608 8.902418 9.065649 9.201052 9.223066 9.223073 9.247414 9.269661 9.288244 9.291237 ] [8.936615 9.077 9.152468 9.1537075 9.313195 9.314999 9.373196 9.400535 9.434517 9.445862 ]]
5.LSH(Locality-Sensitive Hashing (binary flat index))
# LSH index using nbits = 2 * d bits per vector (presumably the hash-code length — confirm).
nbits = 2 * d
index = faiss.IndexLSH(d, nbits)
index.train(data)
index.add(data)
# Ask for 10 results per query vector; the printed values are small whole
# numbers (see the output below), not float L2 distances.
dis, ind = index.search(query, 10)
print(dis)
[[ 8. 10. 10. 10. 10. 10. 10. 11. 11. 11.] [ 7. 8. 9. 9. 9. 10. 10. 10. 10. 10.] [ 7. 8. 8. 9. 9. 9. 9. 9. 9. 9.] [ 9. 9. 10. 11. 12. 12. 12. 12. 12. 12.] [ 6. 6. 6. 7. 7. 8. 8. 8. 8. 8.] [ 8. 8. 8. 9. 9. 9. 9. 9. 10. 10.] [ 6. 7. 8. 8. 9. 9. 9. 9. 9. 9.] [ 9. 9. 9. 9. 9. 9. 9. 9. 9. 10.] [ 7. 8. 8. 8. 8. 8. 8. 9. 9. 9.] [ 9. 9. 9. 10. 10. 10. 10. 10. 10. 10.]]
6.SQ量化(Scalar quantizer (SQ) in flat mode)
# Flat index that stores scalar-quantized vectors.
# NOTE(review): the second argument is a quantizer-type enum value, not a bit
# width; 4 appears to correspond to faiss.ScalarQuantizer.QT_fp16 in current
# faiss releases — confirm against the installed version and prefer the named
# constant.
index = faiss.IndexScalarQuantizer(d, 4)
index.train(data)
index.add(data)
# Ask for 10 results per query vector and print the distance matrix.
dis, ind = index.search(query, 10)
print(dis)
[[8.623227 8.777792 8.785317 8.828824 8.83549 8.845292 8.896896 8.914818 8.922382 8.934983 ] [9.028506 9.037546 9.099248 9.1526165 9.16542 9.19639 9.200499 9.224975 9.274046 9.3053875] [8.064029 8.21301 8.310526 8.376435 8.457833 8.462002 8.501087 8.550647 8.556992 8.624525 ] [8.19665 8.210531 8.346436 8.444769 8.452809 8.454114 8.4745245 8.496618 8.510042 8.525612 ] [8.370452 8.547959 8.704323 8.733619 8.763926 8.776738 8.829511 8.835644 8.857149 8.859046 ] [8.29591 8.432422 8.435944 8.454732 8.542395 8.565367 8.579683 8.621871 8.632034 8.644775 ] [8.609016 8.612934 8.72663 8.734133 8.758857 8.797326 8.797966 8.798654 8.815295 8.8382225] [8.378947 8.521084 8.711153 8.726161 8.748383 8.759655 8.768218 8.769182 8.792372 8.834644 ] [8.340463 8.48951 8.659344 8.664954 8.702756 8.741513 8.741941 8.768993 8.781276 8.852154 ] [8.520282 8.574987 8.683459 8.769213 8.7820425 8.85128 8.881118 8.906741 8.907756 8.924014 ]]
7.PQ量化(Product quantizer (PQ) in flat mode)
# Product-quantizer index in flat mode: each vector is split into M
# sub-vectors, and each sub-vector is encoded with nbits bits.
M = 8 # number of sub-vectors; must be a divisor of d
nbits = 6 # bits per sub-vector code. NOTE(review): the original note said only 8, 12 or 16 are allowed, which contradicts the value 6 used here (and this cell did produce output) — confirm the constraint for the installed faiss version
index = faiss.IndexPQ(d, M, nbits)
index.train(data)
index.add(data)
# Ask for 10 results per query vector and print the (approximate) distance matrix.
dis, ind = index.search(query, 10)
print(dis)
[[5.3184814 5.33667 5.3638916 5.366333 5.3704834 5.4000244 5.404663 5.415283 5.425659 5.427246 ] [5.6835938 5.686035 5.687134 5.7489014 5.76062 5.7731934 5.7766113 5.7875977 5.798828 5.7991943] [4.902588 5.0057373 5.0323486 5.036255 5.045044 5.048828 5.0498047 5.0499268 5.072998 5.0737305] [4.844116 4.850586 4.868042 4.8946533 4.8997803 4.8999023 4.902954 4.909546 4.9210205 4.921875 ] [5.279419 5.333252 5.3344727 5.3431396 5.35083 5.357422 5.366211 5.3862305 5.38855 5.3936768] [5.019409 5.048706 5.0942383 5.1052246 5.116455 5.157593 5.159424 5.168457 5.171875 5.194092 ] [5.0563965 5.0909424 5.1367188 5.1534424 5.1724854 5.199951 5.20105 5.2144775 5.214966 5.23938 ] [5.16333 5.173706 5.2418213 5.265259 5.265869 5.274414 5.291382 5.307495 5.309204 5.310425 ] [5.1501465 5.2508545 5.291992 5.3186035 5.3205566 5.328369 5.336548 5.3479004 5.35376 5.360962 ] [5.2751465 5.2772217 5.279663 5.3304443 5.350708 5.3571777 5.3669434 5.373047 5.373413 5.382324 ]]
8.倒排表乘积量化(IVFADC (coarse quantizer+PQ on residuals))
# IVFADC: an inverted file with nlist cells on top of a flat coarse quantizer,
# with product quantization (M sub-vectors, nbits bits each) applied to the
# residuals, as the section title states.
M = 8
nbits = 4
nlist = 50
# Flat index used as the coarse quantizer.
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)
# Training is required before add(); nprobe is not set here, so the library default applies.
index.train(data)
index.add(data)
# Ask for 10 results per query vector and print the (approximate) distance matrix.
dis, ind = index.search(query, 10)
print(dis)
[[5.1985765 5.209732 5.233874 5.237282 5.2553835 5.262968 5.270462 5.2895284 5.2908745 5.302353 ] [5.5696826 5.5942397 5.611737 5.6186624 5.619787 5.643144 5.646076 5.676093 5.682111 5.6982036] [4.7446747 4.824335 4.834736 4.844829 4.850663 4.853364 4.867393 4.873641 4.8785725 4.88787 ] [4.783175 4.797909 4.8491716 4.85687 4.857151 4.8586845 4.860058 4.866444 4.868099 4.885188 ] [5.1260395 5.134188 5.1386065 5.141901 5.1756086 5.192538 5.1938267 5.1975694 5.199704 5.2012296] [4.882325 4.900981 4.9040375 4.911916 4.916094 4.923492 4.928433 4.928472 4.937878 4.95728 ] [4.9729834 4.976016 4.984484 5.0074816 5.0200887 5.0217285 5.029479 5.029899 5.0346465 5.0349855] [5.1357193 5.147153 5.1525207 5.189519 5.217377 5.220489 5.2341766 5.239973 5.2411985 5.253551 ] [5.0623484 5.087064 5.1075807 5.109309 5.110051 5.1330123 5.1387715 5.1431603 5.151037 5.1516275] [5.12455 5.163775 5.1762547 5.185327 5.190364 5.19723 5.2099175 5.2115583 5.214532 5.2182474]]
cell-probe方法
为了加速索引过程,经常采用划分子类空间(如k-means)的方法,虽然这样无法保证最后返回的结果是完全正确的。先划分子类空间,再在部分子空间中搜索的方法,就是cell-probe方法。
具体流程为:
1)数据集空间被划分为n个部分,在k-means中,表现为n个类;
2)每个类中的向量保存在一个倒排表中,共有n个倒排表;
3)查询时,选中nprobe个倒排表;
4)将这几个倒排表中的向量与查询向量作对比。
在这种方法中,只需要排查数据库中的一部分向量,大约只有nprobe/n的数据;之所以只是"大约",是因为每个倒排表的长度并不一致(每个类中的向量个数不一定相等)。
cell-probe粗量化
在一些索引类型中,需要一个Flat index作为粗量化器,如IndexIVFFlat,在训练的时候会将类中心保存在Flat index中,在add和search阶段,会首先判定将其落入哪个类空间。在search阶段,nprobe参数需要调整以权衡检索精度与检索速度。
实验表明,对高维数据,需要维持比较高的nprobe数值才能保证精度。
与LSH的优劣
LSH也是一种cell-probe方法,但与上述基于倒排表的方法相比,LSH有以下两点不足:
1)LSH需要大量的哈希函数(即大量的空间划分)才能达到可接受的精度,会带来额外的内存开销;
2)哈希函数没有针对输入数据进行适配(不是根据数据分布学习得到的)。