Faiss IVFPQ index construction

Assuming the embedding vectors have already been computed, build the index (choose nlist and the size of the random training sample as needed).

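The loader below expects each line of every file under ./data to be a tab-separated pair: an identifier, then the embedding as a bracketed, comma-separated list of floats. A minimal sketch that writes one such line; the file name data/part-00000, the id doc_000001, and the 128-dimensional random vector are illustrative assumptions, not part of the original pipeline:

import os
import numpy as np

os.makedirs('./data', exist_ok=True)
vec = np.random.rand(128).astype('float32')  # toy 128-dim embedding (assumed dimensionality)
line = "doc_000001\t[" + ",".join(str(x) for x in vec) + "]\n"
with open('./data/part-00000', 'w') as f:  # hypothetical file name
    f.write(line)
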
import numpy as np
import faiss
import pickle
from tqdm import tqdm
import time
import os
import random

# Read embeddings in batches while preserving the id-to-vector correspondence
def read_embeddings(directory, batch_size=10000):
    for root, dirs, files in os.walk(directory):
        for file_name in files:
            cur_file = os.path.join(root, file_name)
            print("Loading file >>>", cur_file)
            with open(cur_file, 'r') as f:
                lines = f.readlines()

            batch_ids = []
            batch_embeddings = []
            for i, line in enumerate(tqdm(lines, ncols=100)):
                if i > 0 and i % batch_size == 0:
                    yield np.array(batch_embeddings, dtype='float32'), batch_ids
                    batch_ids = []
                    batch_embeddings = []
                # Each line is "<identifier>\t[v1,v2,...]"
                parts = line.strip().split('\t')
                identifier = parts[0]
                vector_str = parts[1]
                # Strip the surrounding brackets and parse the comma-separated floats
                # (np.fromstring with sep is deprecated, so split and convert instead)
                vector = np.array(vector_str[1:-1].split(','), dtype='float32')
                batch_ids.append(identifier)
                batch_embeddings.append(vector)
            # Flush the final partial batch for this file
            if batch_embeddings:
                yield np.array(batch_embeddings, dtype='float32'), batch_ids

try:
    # 读取嵌入向量
    directory_path = './data'
    
    embeddings_batches = []
    ids = []
    
    for embeddings_batch, ids_batch in read_embeddings(directory_path):
        embeddings_batches.append(embeddings_batch)
        ids.extend(ids_batch)
    
    print("Data loading complete, start building the index")
    N = sum(batch.shape[0] for batch in embeddings_batches)
    D = embeddings_batches[0].shape[1]
    print(f"Embeddings shape: {N}x{D}")
    
    nlist = 100000  # number of IVF clusters (inverted lists)
    m = 32          # number of PQ sub-quantizers (D must be divisible by m)
    n_bits = 8      # bits per PQ code
    
    quantizer = faiss.IndexFlatL2(D)  # coarse quantizer used for cluster assignment
    
    index = faiss.IndexIVFPQ(quantizer, D, nlist, m, n_bits)
    
    print("Start training the index...")
    all_embeddings = np.vstack(embeddings_batches)  # stack all batches into a single (N, D) matrix
    train_start = time.time()
    # Randomly select a subsample for training
    sample_size = min(1000000, N)  # use at most 1,000,000 training vectors
    sample_indices = random.sample(range(N), sample_size)
    sample_embeddings = all_embeddings[sample_indices]
    print("Training on a random subsample")
    
    index.train(sample_embeddings)
    train_end = time.time()
    print(f"Training completed, time taken: {(train_end - train_start) / 3600:.2f} hours")
    
    # 分批添加嵌入到索引中
    print("Start adding embeddings to the index...")
    add_start = time.time()
    batch_count = 0
    for embeddings_batch in embeddings_batches:
        batch_count += 1
        if batch_count % 100 == 0:
            print(f"Added {batch_count} batches")
        index.add(embeddings_batch)
    add_end = time.time()
    print(f"Adding embeddings completed, time taken: {(add_end - add_start) / 3600:.2f} hours")
    
    print("Start saving the index...")
    save_start = time.time()
    faiss.write_index(index, "index_ivfpq_1b.faiss")
    save_end = time.time()
    print(f"Index saved, time taken: {(save_end - save_start) / 3600:.2f} hours")
    
    # Map the sequential Faiss ids back to the original string identifiers
    index_to_identifier = {"faiss_v1_" + str(i): identifier for i, identifier in enumerate(ids)}
    
    with open('index_to_identifier_1b.pkl', 'wb') as f:
        pickle.dump(index_to_identifier, f)
    print("Index to identifier mapping saved.")
except Exception as e:
    print("Error occurred during index construction:", str(e))

Vector search

import time
import numpy as np
import faiss
import pickle

# Load the saved index
index = faiss.read_index("index_ivfpq_1b.faiss")

# Load the mapping from Faiss ids to original identifiers
with open('index_to_identifier_1b.pkl', 'rb') as f:
    index_to_identifier = pickle.load(f)
# Number of inverted lists (clusters) to probe at query time; higher values trade speed for recall
index.nprobe = 100
# Limit the number of CPU threads used by Faiss
faiss.omp_set_num_threads(4)  # adjust the thread count to your hardware

# Define the query vector and its identifier inline
query_embedding = np.array([[-0.01962059736251831, 0.11334816366434097, -0.09471801668405533, 0.0641612783074379, 0.016695162281394005, 0.03470868244767189, 0.059329044073820114, -0.024794576689600945, -0.012960868887603283, -0.0744692012667656, -0.07942882925271988, 0.19218777120113373, 0.14370097219944, 0.11092912405729294, -0.06869585067033768, 0.08476870507001877, 0.10311301797628403, -0.09529904276132584, 0.11519007384777069, 0.07435101270675659, -0.07236043363809586, 0.010397439822554588, -0.06027359142899513, -0.08405963331460953, 0.031723152846097946, -0.1143064945936203, 0.18072178959846497, 0.07466364651918411, 0.10553380101919174, -0.10898686945438385, -0.19313931465148926, 0.15539272129535675, -0.11933872103691101, -0.13383139669895172, 0.0754752978682518, 0.04579591378569603, 0.07465954124927521, -0.0241111870855093, -0.06121497601270676, -0.10494254529476166, -0.01837378740310669, 0.1292468160390854, -0.0056768800131976604, 0.06756076216697693, -0.08115670830011368, 0.09304261207580566, 0.06945249438285828, -0.057487890124320984, 0.07290451973676682, -0.01492359396070242, 0.14174117147922516, 0.0752357617020607, 0.014304161071777344, -0.0023451936431229115, 0.08765687793493271, 0.10875667631626129, 0.1779395043849945, -0.04857892543077469, 0.054570272564888, -0.15957848727703094, 0.008002348244190216, 0.03754493221640587, 0.07620261609554291, 0.01903180405497551, 0.14646433293819427, -0.07392526417970657, 0.02997334860265255, -0.04795815050601959, 0.039741817861795425, -0.06323029100894928, -0.0361541248857975, 0.1155063807964325, -0.03679197281599045, 0.08797583729028702, -0.068557009100914, -0.14507029950618744, 0.06844533234834671, 0.09862343966960907, 0.012137680314481258, -0.012296526692807674, 0.05485907569527626, 0.08134670555591583, 0.06546603888273239, 0.10151205956935883, -0.1254400908946991, 0.06678715348243713, 0.015612985007464886, 0.03761797398328781, 0.11426421254873276, -0.10608682036399841, 0.0054876371286809444, -0.13291053473949432, -0.1383194625377655, -0.060186877846717834, 0.040753982961177826, 0.025832200422883034, 0.06087275967001915, 0.07576646655797958, -0.025103572756052017, 0.0819762796163559, 0.06338494271039963, 0.09223338961601257, 0.11740309000015259, 0.16588829457759857, 0.0016070181736722589, -0.11642675846815109, 0.06580012291669846, 0.07179497182369232, -0.11596480011940002, 0.05284847319126129, 0.018308958038687706, 0.2823641896247864, 0.0026317911688238382, -0.013333271257579327, -0.07727757096290588, -0.06593139469623566, 0.06467396765947342, 0.04348631948232651, 0.02083323895931244, -0.004868550691753626, -0.06408777832984924, -0.12004149705171585, 0.09156100451946259, 0.04209277778863907, 0.04682828485965729, 0.06600149720907211, 0.014075364917516708, 0.02114858292043209]], dtype='float32')


query_id = "龙血王手串价格及图片"  # the identifier corresponding to the query vector

s = time.time()

# Number of query vectors and their dimensionality
num_queries, D = query_embedding.shape

# Run the search
k = 10  # return the top-k nearest neighbors
distances, indices = index.search(query_embedding, k)

# Display the results
print(f"Query ID: {query_id}")
print("Top k results:")
for j in range(k):
    idx = indices[0, j]
    distance = distances[0, j]
    if idx != -1:  # valid hit; Faiss returns -1 when fewer than k neighbors are found
        idx = "faiss_v1_" + str(idx)
        identifier = index_to_identifier.get(idx, "Unknown")
        print(f"  {j+1}. ID: {identifier}, Distance: {distance}")
    else:
        print(f"  {j+1}. No result")

e = time.time()
print(f"Time taken for search: {e - s} seconds")