运筹系列64:向量检索工具

1. 概述

1.1 问题描述

底库xb中,有nb个d维向量;检索库xq中,有nq个同样维度的向量。要求对每个检索库中的向量,检索出在底库中(1)k个最近的向量(2)距离在s之内的所有向量。

1.2 可选方案

目前一直的方案:
(1)faiss,功能很全,支持倒排、PQ量化、增删改查等各种功能。支持python和c++语言。非要用java的话,github上有个faiss4java,采用c版本rpc通信的形式。需要下载faiss源码包和faiss4java源码包进行编译。建议把所有第三方库全都打成so文件,合着faiss打包成的so文件一起放到项目文件夹下编译。
(2)caiss,c++实现,使用NHSW,适用于在线系统。官方例子是用作词向量搜索,特点是支持sql和label-to-label的搜索。
(3)mivlus,java实现,需要搭建服务,比较适合线上系统使用,特点是支持增删,性能不错。

2. 暴力方法

暴力方法在faiss中称为IndexFlatL2,使用这种方式需要计算xb和xq构成的距离矩阵。这里对距离进行分解,减少计算量:
在这里插入图片描述
测试5次取平均时间,代码见下:

  • Faiss(C++):1.007s
  • python实现(使用numpy):1.133s(启用多线程后几乎无变化)
  • java实现(使用Nd4j):7.486s
  • java实现(使用javaCPP映射):2.857s

2.1 Faiss原生和numpy复现

Faiss原生方法如下:

import faiss
index = faiss.IndexFlatL2(d)  
index.add(xb)
D, I = index.search(xq, 10)

研究了一下IndexFlatL2的源码,用numpy库简单复现如下:

def rangeSearch(xb,xq,r)
    nb,d = xb.shape
    st = time()
    xbs = np.sum(np.square(xb),axis=1)
    xqs = np.sum(np.square(xq),axis=1)
    ip = 2* np.matmul(xq,xb.T)
    lables = []
    for j in (range(xq.shape[0])):
       dis = xbs + xqs[j]- ip[j]
       labels.append(np.where(dis < r))
    return labels

topK的话,只需要把np.where(dis < r)换成np.argpartition(dis, K)[:K]即可。

2.2 java调用numpy

java没有好用的向量库,测下来速度都不太行,可能就只有一个bytedeco的曲折调用numpy的库还凑合:https://github.com/bytedeco/javacpp-presets/tree/master/numpy,调用方法如下:

import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.bytedeco.cpython.*;
import org.bytedeco.numpy.*;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.numpy.global.numpy.*;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
public class rangesearch {
    public static void main(String[] args) throws IOException {
        int nb = 200000;
        int nq = 1000;
        int d = 256;
        String r = "36";
        //这里开始是读取数据
        float[] xbf = new float[nb*d];
        float[] xqf = new float[nq*d];
        BufferedReader in = new BufferedReader(new FileReader(new File("/Users/chen/xq.txt")));
        String line;  //一行数据
        int row=0;
        while((line = in.readLine()) != null){
            String[] temp = line.split(" ");
            for(int j=0;j<temp.length;j++){
                xqf[row*d+j] = Float.parseFloat(temp[j]);
            }
            row++;
        }
        in.close();
        BufferedReader in2 = new BufferedReader(new FileReader(new File("/Users/chen/xb.txt")));
        row=0;
        while((line = in2.readLine()) != null){
            String[] temp = line.split(" ");
            for(int j=0;j<temp.length;j++){
                xbf[row*d+j] = Float.parseFloat(temp[j]);
            }
            row++;
        }
        in2.close();
        // 读取数据结束
        
        System.out.println("start......");
        long start = System.currentTimeMillis();
        System.setProperty("org.bytedeco.openblas.load", "mkl");
        Py_SetPath(org.bytedeco.numpy.global.numpy.cachePackages());
        Py_Initialize();
        if (_import_array() < 0) {
            System.err.println("numpy.core.multiarray failed to import");
            PyErr_Print();
            System.exit(-1);
        }
        PyObject globals = PyModule_GetDict(PyImport_AddModule("__main__"));
        long[] dimsxb = {nb, d};
        FloatPointer dataxb = new FloatPointer(xbf);
        PyObject xb = PyArray_New(PyArray_Type(), dimsxb.length, new SizeTPointer(dimsxb),
                NPY_FLOAT, null, dataxb, 0, NPY_ARRAY_CARRAY, null);
        PyDict_SetItemString(globals, "xb", xb);
        long[] dimsxq = {nq, d};
        FloatPointer datax = new FloatPointer(xqf);
        PyObject xq = PyArray_New(PyArray_Type(), dimsxq.length, new SizeTPointer(dimsxq),
                NPY_FLOAT, null, datax, 0, NPY_ARRAY_CARRAY, null);
        PyDict_SetItemString(globals, "xq", xq);
        PyRun_StringFlags("import numpy as np;import operator;from functools import reduce;xbs = np.sum(np.square(xb),axis=1);xqs = np.sum(np.square(xq),axis=1);ip = -2* np.matmul(xq,xb.T);labels = list(range(xq.shape[0]));y = np.array(reduce(operator.add, [[-j]+list(np.where(xbs + xqs[j]+ ip[j] < "+r+")[0]) for j in labels])).astype('int32')"
                , Py_single_input, globals, globals, null);
        PyArrayObject y = new PyArrayObject(PyDict_GetItemString(globals, "y"));
        IntPointer datay = new IntPointer(PyArray_BYTES(y)).capacity(PyArray_Size(y));
        long[] dimsy = new long[PyArray_NDIM(y)];
        PyArray_DIMS(y).get(dimsy);
        long end = System.currentTimeMillis();
        System.out.println((end-start)/1000.0);
        System.out.println("y = " + IntIndexer.create(datay, dimsy));
    }
}

POM中添加

    <dependencies>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>numpy-platform</artifactId>
            <version>1.17.3-1.5.2</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>mkl-platform</artifactId>
            <version>2019.5-1.5.2</version>
        </dependency>
    </dependencies>

2.3 Nd4j的java代码

再补充一个Nd4j库的实现,比调用numpy库要慢不少。

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
public class test2 {
    public static void main(String[] args)  throws IOException {
        int nb = 200000;
        int nq = 1000;
        int d = 256;
        float[][] xb = new float[nb][d];
        float[][] xq = new float[nq][d];
        BufferedReader in = new BufferedReader(new FileReader(new File("/Users/chen/xq.txt")));
        String line;  //一行数据
        int row=0;
        //逐行读取,并将每个数组放入到数组中
        while((line = in.readLine()) != null && row < nq){
            String[] temp = line.split(" ");
            for(int j=0;j<temp.length;j++){
                xq[row][j] = Float.parseFloat(temp[j]);
            }
            row++;
        }
        in.close();
        BufferedReader in2 = new BufferedReader(new FileReader(new File("/Users/chen/xb.txt")));
        row=0;
        //逐行读取,并将每个数组放入到数组中
        while((line = in2.readLine()) != null && row<nb){
            String[] temp = line.split(" ");
            for(int j=0;j<temp.length;j++){
                xb[row][j] = Float.parseFloat(temp[j]);
            }
            row++;
        }
        in2.close();
        System.out.println("start");
        L2Index index = new L2Index(xb);
        long start = System.currentTimeMillis();
        index.range_search(xq,6);
        long end = System.currentTimeMillis();
        System.out.println((end-start)/1000.0);
        //index.print_res();
    }
}
class L2Index {
    INDArray xb;
    ArrayList<Number> starts = new ArrayList();
    ArrayList<Number> labels = new ArrayList();
    L2Index(float[][] data) {xb = Nd4j.create(data);}
    void range_search(float[][] xqdata, float radius) {
        float radissquare = radius * radius;
        INDArray xq = Nd4j.create(xqdata);
        int nx = xb.rows();
        int ny = xq.rows();
        INDArray xbs = xb.mul(xb).sum(1);
        INDArray xqs = xq.mul(xq).sum(1);
        INDArray ip = xb.mmul(xq.transpose()).mul(-2);
        for (int j = 0; j < ny; j++) {
            starts.add(labels.size());
            INDArray res = ip.getColumn(j).add(xqs.getFloat(j)).addColumnVector(xbs).lte(radissquare);
            for (int i = 0; i < nx; i++) {
                if (res.getInt(i)==1){
                    labels.add(i);// 这里的for循环用了很长时间,如果只返回距离矩阵的话,可以减少一半时间
                }
            }
        }
    }
    void print_res() {
        System.out.println(starts);
        System.out.println(labels);
    }
}

POM添加:

    <dependencies>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacpp</artifactId>
            <version>1.5.2</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>1.0.0-beta6</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>1.0.0-beta6</version>
            <classifier>macosx-x86_64-avx2</classifier>
        </dependency>
    </dependencies>

3. Index_factory

index_factory通过字符串来创建索引,字符串包括三部分:预处理、倒排、编码。
例如index = index_factory(128, “OPQ16_64,IMI2x8,PQ8+16”): 处理128维的向量,使用OPQ来预处理数据16是OPQ内部处理的blocks大小,64为OPQ后的输出维度;使用multi-index建立65536(2^16)和倒排列表;编码采用8字节PQ和16字节refine的Re-rank方案。
构建数据:10万256维向量,查询向量1000个。
在这里插入图片描述

首先看下暴力方法:构建阶段不需用时,检索使用0.419秒,索引文件98M。 查询结果:

[ 436  769 ……  181  865]
[31.474623 33.287354…… 38.098305 38.10317 ]

3.1 预处理支持

  • PCA:PCA64表示通过PCA降维到64维(PCAMatrix实现);PCAR64表示PCA后添加一个随机旋转。
  • OPQ:OPQ16表示为数据集进行16字节编码进行预处理(OPQMatrix实现),对PQ索引很有效但是训练时也会慢一些。

3.2 倒排索引

  • IVF:IVF4096表示使用粗量化器IndexFlatL2将数据分为4096份
  • IMI:IMI2x8表示通过Mutil-index使用2x8个bits(MultiIndexQuantizer)建立2^(2*8)份的倒排索引。
  • IDMap:如果不使用倒排但需要add_with_ids,可以通过IndexIDMap来添加id

加速查找的典型方法是对数据集使用KD树进行划分:d维特征空间被切分为nlist个块,搜索时,检索离目标距离最近的nprobe个块,根据倒排列表检索nprobe个块中的所有数据。
所谓倒排,即使用子划分的查找结果来推断原数据的结果。
在这里插入图片描述

注意:nprobes默认为1。对于高维的数据,要达到较好的召回,需要的nprobes可能很大。nprobe参数始终是调整速度和结果精度之间权衡的一种方式。设置nprobe = nlist给出与强力搜索相同的结果(但较慢)。
这便是IndexIVFFlat,它需要另一个索引(quantizer)来记录倒排列表。metric目前有两种,METRIC_L2和METRIC_INNER_PRODUCT。
注意IndexIVFKmeans 和 IndexIVFSphericalKmeans 不是对象而是方法,它们可以返回IndexIVFFlat对象。
下面的例子构建索引使用0.581秒,检索的话nprobe=1时使用0.022秒,nprobe = 10时使用0.122秒,索引文件99M

import faiss
nlist = 100
quantizer = faiss.IndexFlatL2(d)    # 内部的索引方式依然不变
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
index.train(xb)
index.add(xb)
index.nprobe = 10

nprobe=1时有些误差。

3.3 量化

  • Flat:存储原始向量,通过IndexFlat或IndexIVFFlat实现
  • PQ:PQ16使用16个字节编码向量,通过IndexPQ或IndexIVFPQ实现
  • PQ8+16:表示通过8字节来进行PQ,16个字节对第一级别量化的误差再做PQ,通过IndexIVFPQR实现

PQ的原理如下:首先对维度进行拆分,然后对子维度进行k-means聚类,这里聚类结果用量化编码表示。最后再把聚类结果合起来。
PQ算法虽然快,但是只适用于topK计算,对于range search明显不适合。
在这里插入图片描述

下面的例子构建索引使用10.747秒,检索使用0.21秒,索引文件1.8M

m = 16                                   # number of subquantizers
n_bits = 8                               # bits allocated per subquantizer
index = faiss.IndexPQ (d, m, n_bits)        # Create the index
index.train(xb)                            # Training
index.add(xb)  

结果为:

[ 229 1172  ……  964  353]
[24.422462 25.364008 …… 27.67841  27.715546]

不论是距离还是索引,都有很大的改变

3.4 常见索引方法列表

下面列举一下重要的索引方法以及对应的index_factory:
在这里插入图片描述
这里介绍一下IVFPQ的原理,首先使用kd树粗量化进行倒排,残差则用量化:
在这里插入图片描述

构建索引使用5.604秒,检索使用0.016秒,索引文件1.9M

nlist = 100
m = 8
st0 =time()
quantizer = faiss.IndexFlatL2(d)    # 内部的索引方式依然不变
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)# 每个向量都被编码为8个字节大小
index.train(xb)
index.add(xb)

结果为:

[145  429……82  431]
[21.083157 21.362541 ……24.43357  24.436153]

速度非常快,但是好像距离差的更大了。

4. 特殊方法

4.1 IDMap

可以直接对id进行搜索,调用IndexIDMap方法构建新的索引,并使用add_with_ids方法将原数据与id进行关联

import faiss
index = faiss.IndexFlatL2(d)   # build the index
ids = np.arange(100000, 200000)  #id设定为6位数整数,默认id从0开始,这里我们将其设置从100000开始
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(xb, ids)
print(index2.ntotal)
k = 4   # we want to see 4 nearest neighbors
D, I = index2.search(xq, k) # sanity check

IndexIVF可以使用Direct Map:

  • DirectMap.NoMap: no mapping is stored, reconstruction is not possible (default).
  • DirectMap.Array: the direct map is an array. The indices are assumed to be sequential, which rules out add_with_ids
  • DirectMap.Hashtable: the direct map is a hashtable. Indices can be arbitrary and add_with_ids works (provided indices are distinct).

4.2 range search

支持IndexFlat, IndexIVFFlat, IndexScalarQuantizer, IndexIVFScalarQuantizer
Python返回的是一个三维元组,分别是:起始和终止位置、距离列表、下标列表。
在这里插入图片描述
获取第point_id对应的下标列表的方法如下:

res[2][res[0][point_id]:res[0][point_id+1]]

4.3 重建与删除

使用reconstruct_n方法可以重建ids列表中的索引,其他索引不变。
使用remove_ids方法可以删除索引。对IndexFlat, IndexIVFFlat, IDMap有效。

4.4 Annoy

Annoy通过随机挑选两个点,并使用垂直于这个点的等距离超平面将集合划分为两部分。
依此类推,直到每个集合最多剩余k个点:
在这里插入图片描述
下面是参考的python code,使用pypi上的annoy库,有两个参数可以用来调节Annoy 树的数量n_trees和搜索期间检查的节点数量search_k。n_trees在构建时提供,并影响构建时间和索引大小。 较大的值将给出更准确的结果,但更大的索引。search_k在运行时提供,并影响搜索性能。 较大的值将给出更准确的结果,但将需要更长的时间返回。

from annoy import AnnoyIndex
t = AnnoyIndex(d, 'euclidean')
for i in range(nb):
    t.add_item(i,xb[i])
t.build(20)
res = []
for x in xq:
    res.append(t.get_nns_by_vector(x,100))

4.5 HNSW

Hierarchcal Navigable Small World graphs,层次小世界导航
首先是构造近邻关系图,Delaunary三角剖分是一个直观方案。 但NSW没有采用德劳内三角剖分法来构成德劳内三角网图,原因之一是德劳内三角剖分构图算法时间复杂度太高,换句话说,构图太耗时。原因之二是德劳内三角形的查找效率并不一定最高,如果初始点和查找点距离很远的话我们需要进行多次跳转才能查到其临近点,需要“高速公路”机制(Expressway mechanism, 这里指部分远点之间拥有线段连接,以便于快速查找)。在理想状态下,我们的算法不仅要满足上面三条需求,还要算法复杂度低,同时配有高速公路机制的构图法。
NSW朴素构图算法在这里:向图中逐个插入点,插图一个全新点时,通过朴素想法中的朴素查找法(通过计算“友点”和待插入点的距离来判断下一个进入点是哪个点)查找到与这个全新点最近的m个点(m由用户设置),连接全新点到m个点的连线。
至于Hierarchcal,的是跳表结构,每个点都有50%的概率进入上一层,这样就构成了逐级稀疏的结构。
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值