Inner Product Topk计算
术语
- corpus: 总候选集合(本文主要针对corpus过大的问题)
- queries: 要查询的问题
- emb:将query或corpus内数据向量化的函数
目的
从corpus中检索与每个query的emb内积最相近的topk个
Inner Product Topk
内积计算topk,主要有以下两种方式:
- faiss索引
- torch的topk计算(本文主要内容)
faiss
使用faiss可以方便的创建索引并搜索。然而,当corpu过大时,需要把corpus分块创建索引,检索时需要将索引逐个加载并搜索,并将搜索结果合并,其时间复杂度反而较高。
(faiss-gpu面对过大corpus也无法存储完整,也需要分块,且搜索时需先加载成cpu再转移到gpu上,加载过程时间消耗大,直接在cpu上计算与加载到gpu后总计算时间相当,因此可以直接使用faiss-cpu)
torch.topk
使用torch直接计算topk,面对大corpus时,转换为数据流的topk问题,因此不需要提前创建索引,可以实时的emb并计算topk(更适用于emb层的参数更新导致原索引可能失效)。
实现思路
以某个大小limit为界(取决于自己电脑的gpu的显存与计算能力,可以随意更改)
- 当corpus>limit,使用数据流方式
- 当corpus<=limit,显卡可以存储全部corpus,可以直接用torch.matmul和torch.topk计算
代码实现
# -*- coding: utf-8 -*-
# @Time : 2023/4/22 18:33
# @Author : ZiUNO
# @File : flat_ip_search.py
# @Software: PyCharm
import hashlib
import os.path
from multiprocessing.dummy import Pool
from queue import PriorityQueue
import numpy as np
import torch
from tqdm import tqdm
class FlatIPSearch:
batch_size = 64
topk_batch_size = 1024 * 10
corpus_limitation = 10 ** 4
temp_save_dir = "flat_ip_search"
dim = 768
@staticmethod
def _get_emb_file_name(texts: [str]):
md5_value = hashlib.md5("".join(texts).encode()).hexdigest()
file_path = os.path.join(FlatIPSearch.temp_save_dir, f"{
md5_value}.memmap")
return file_path
@staticmethod
def _emb_texts(texts: [str], emb_func, file_path: str = None):
file_path = file_path or FlatIPSearch._get_emb_file_name(texts)
if not os.path.exists(file_path):
batch_size = FlatIPSearch.batch_size
embedding_memmap = np.memmap(file_path, dtype='float32', mode='w+', shape=(len(texts), FlatIPSearch.dim))
for index in tqdm(range(0, len(texts), batch_size), desc="Embed"):
batch_emb = texts[index: index + batch_size]
b_size = len(batch_emb)
batch_emb = emb_func(batch_emb)
embedding_memmap[index: index + b_size] = batch_emb.detach().cpu().numpy()
del embedding_memmap
return np.memmap(file_path, dtype='float32', mode