假设已有训练好的向量（embedding），构建 IVF-PQ 索引（nlist 和随机训练样本数量按需选取）
import numpy as np
import faiss
import pickle
from tqdm import tqdm
import time
import os
import random
def read_embeddings(directory, batch_size=10000):
    """Stream ``(embeddings, ids)`` batches from every file under *directory*.

    Each non-blank line must look like ``<identifier>\\t[v1,v2,...,vD]``: an
    identifier, a tab, then comma-separated floats. The first and last
    characters of the second field (the brackets) are dropped before parsing.

    Args:
        directory: Root directory, walked recursively with ``os.walk``.
        batch_size: Maximum number of rows per yielded batch. Batches never
            span two files, so the last batch of a file may be smaller.

    Yields:
        ``(embeddings, ids)``. ``embeddings`` is a ``float32`` array with one
        row per parsed line. ``ids`` is the list of identifiers in the same
        row order.

    Raises:
        ValueError: If ``batch_size`` is not positive, or a non-blank line
            has no tab-separated vector field.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    for root, dirs, files in os.walk(directory):
        # Sort in place so os.walk descends in a reproducible order. Row order
        # decides the sequential faiss ids assigned when the batches are added.
        dirs.sort()
        for file_name in sorted(files):
            cur_file = os.path.join(root, file_name)
            print("Loading file >>>", cur_file)
            batch_ids = []
            batch_embeddings = []
            # Iterate lazily instead of readlines() so a huge file is never
            # held in memory as a list of strings.
            # NOTE(review): assumes the data files are UTF-8 (identifiers may be
            # non-ASCII, e.g. Chinese text) rather than relying on the platform
            # default encoding — confirm.
            with open(cur_file, 'r', encoding='utf-8') as fh:
                for line_no, line in enumerate(tqdm(fh, ncols=100), start=1):
                    line = line.strip()
                    if not line:
                        continue  # tolerate blank / trailing empty lines
                    parts = line.split('\t')
                    if len(parts) < 2:
                        raise ValueError(
                            f"{cur_file}:{line_no}: expected '<id>\\t[v1,...,vD]', "
                            f"got {line[:80]!r}"
                        )
                    identifier, vector_str = parts[0], parts[1]
                    vector = np.fromstring(vector_str[1:-1], sep=',')
                    batch_ids.append(identifier)
                    batch_embeddings.append(vector)
                    if len(batch_ids) == batch_size:
                        yield np.array(batch_embeddings, dtype='float32'), batch_ids
                        batch_ids = []
                        batch_embeddings = []
            if batch_embeddings:
                yield np.array(batch_embeddings, dtype='float32'), batch_ids
def _sample_rows(batches, offsets, sample_size):
    """Return ``sample_size`` distinct rows drawn uniformly at random from *batches*.

    Draws the same global row indices as
    ``np.vstack(batches)[random.sample(range(N), sample_size)]``, but returns
    them in ascending row order. It never builds the concatenated N x D
    matrix, which would roughly double peak memory.

    Args:
        batches: List of 2-D float32 arrays with equal column counts.
        offsets: Cumulative row counts of length ``len(batches) + 1``.
            ``offsets[b]`` is the global index of the first row of
            ``batches[b]``, and ``offsets[-1]`` is the total row count N.
        sample_size: Number of rows to draw, ``0 < sample_size <= N``.

    Returns:
        A new C-contiguous float32 array of shape ``(sample_size, D)``.
    """
    total = int(offsets[-1])
    picked = np.sort(np.array(random.sample(range(total), sample_size), dtype=np.int64))
    # `picked` is sorted, so the samples that fall inside batch b occupy the
    # slice picked[bounds[b]:bounds[b + 1]].
    bounds = np.searchsorted(picked, offsets)
    parts = [
        batches[b][picked[lo:hi] - offsets[b]]
        for b, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:]))
        if lo < hi
    ]
    return np.vstack(parts)


# Build an IVF-PQ index over every embedding under ./data. Save it together
# with a "faiss_v1_<row>" -> identifier mapping.
try:
    directory_path = './data'
    embeddings_batches = []
    ids = []
    for embeddings_batch, ids_batch in read_embeddings(directory_path):
        embeddings_batches.append(embeddings_batch)
        ids.extend(ids_batch)
    print("Data loading complete, start building the index")
    if not embeddings_batches:
        raise ValueError(f"No embeddings found under {directory_path!r}")
    # offsets[b] is the global row of the first row of batch b. The batches
    # are added below in this same order. The identifier mapping assumes faiss
    # assigns sequential ids in add order.
    offsets = np.cumsum([0] + [batch.shape[0] for batch in embeddings_batches])
    N = int(offsets[-1])
    D = embeddings_batches[0].shape[1]
    print(f"Embeddings shape: {N}x{D}")
    nlist = 100000  # number of IVF lists (coarse k-means centroids)
    m = 32          # PQ sub-quantizers per vector; IndexIVFPQ requires D % m == 0
    n_bits = 8      # bits per PQ sub-code
    if D % m != 0:
        raise ValueError(f"Vector dimension {D} must be divisible by m={m} for IndexIVFPQ")
    # NOTE(review): faiss k-means warns below ~39 training points per centroid.
    # 1M samples for 100k lists is about 10 per list; consider a larger sample.
    sample_size = min(1000000, N)
    if sample_size < nlist:
        raise ValueError(
            f"Need at least nlist={nlist} training vectors, only {sample_size} available"
        )
    quantizer = faiss.IndexFlatL2(D)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, m, n_bits)
    print("Start training the index...")
    train_start = time.time()
    # Sample directly from the batches instead of np.vstack-ing the whole
    # dataset, which would double peak memory at this scale.
    sample_embeddings = _sample_rows(embeddings_batches, offsets, sample_size)
    print("随机选取样本训练")
    index.train(sample_embeddings)
    del sample_embeddings  # release the training copy before adding everything
    train_end = time.time()
    print(f"Training completed, time taken: {(train_end - train_start) / 3600:.2f} hours")
    print("Start adding embeddings to the index...")
    add_start = time.time()
    for batch_no, embeddings_batch in enumerate(embeddings_batches, start=1):
        if batch_no % 100 == 0:
            print(batch_no)  # coarse progress indicator
        index.add(embeddings_batch)
    add_end = time.time()
    print(f"Adding embeddings completed, time taken: {(add_end - add_start) / 3600:.2f} hours")
    print("Start saving the index...")
    save_start = time.time()
    faiss.write_index(index, "index_ivfpq_1b.faiss")
    save_end = time.time()
    print(f"Index saved, time taken: {(save_end - save_start) / 3600:.2f} hours")
    index_to_identifier = {"faiss_v1_" + str(i): identifier for i, identifier in enumerate(ids)}
    with open('index_to_identifier_1b.pkl', 'wb') as f:
        pickle.dump(index_to_identifier, f)
    print("Index to identifier mapping saved.")
except Exception as e:
    print("Error occurred during index construction:", str(e))
    # Re-raise so the traceback is visible and the process exits non-zero,
    # instead of the build appearing to have succeeded.
    raise
向量查询（加载索引与 ID 映射，检索 Top-k 近邻）
import time
import numpy as np
import faiss
import pickle
# Load the IVF-PQ index and the "faiss_v1_<id>" -> identifier mapping written
# by the build step, then print the top-k neighbours of one query vector.
index = faiss.read_index("index_ivfpq_1b.faiss")
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a mapping produced by a trusted build.
with open('index_to_identifier_1b.pkl', 'rb') as f:
    index_to_identifier = pickle.load(f)
# Inverted lists scanned per query: higher improves recall at the cost of speed.
index.nprobe = 100
faiss.omp_set_num_threads(4)
query_embedding = np.array([[-0.01962059736251831, 0.11334816366434097, -0.09471801668405533, 0.0641612783074379, 0.016695162281394005, 0.03470868244767189, 0.059329044073820114, -0.024794576689600945, -0.012960868887603283, -0.0744692012667656, -0.07942882925271988, 0.19218777120113373, 0.14370097219944, 0.11092912405729294, -0.06869585067033768, 0.08476870507001877, 0.10311301797628403, -0.09529904276132584, 0.11519007384777069, 0.07435101270675659, -0.07236043363809586, 0.010397439822554588, -0.06027359142899513, -0.08405963331460953, 0.031723152846097946, -0.1143064945936203, 0.18072178959846497, 0.07466364651918411, 0.10553380101919174, -0.10898686945438385, -0.19313931465148926, 0.15539272129535675, -0.11933872103691101, -0.13383139669895172, 0.0754752978682518, 0.04579591378569603, 0.07465954124927521, -0.0241111870855093, -0.06121497601270676, -0.10494254529476166, -0.01837378740310669, 0.1292468160390854, -0.0056768800131976604, 0.06756076216697693, -0.08115670830011368, 0.09304261207580566, 0.06945249438285828, -0.057487890124320984, 0.07290451973676682, -0.01492359396070242, 0.14174117147922516, 0.0752357617020607, 0.014304161071777344, -0.0023451936431229115, 0.08765687793493271, 0.10875667631626129, 0.1779395043849945, -0.04857892543077469, 0.054570272564888, -0.15957848727703094, 0.008002348244190216, 0.03754493221640587, 0.07620261609554291, 0.01903180405497551, 0.14646433293819427, -0.07392526417970657, 0.02997334860265255, -0.04795815050601959, 0.039741817861795425, -0.06323029100894928, -0.0361541248857975, 0.1155063807964325, -0.03679197281599045, 0.08797583729028702, -0.068557009100914, -0.14507029950618744, 0.06844533234834671, 0.09862343966960907, 0.012137680314481258, -0.012296526692807674, 0.05485907569527626, 0.08134670555591583, 0.06546603888273239, 0.10151205956935883, -0.1254400908946991, 0.06678715348243713, 0.015612985007464886, 0.03761797398328781, 0.11426421254873276, -0.10608682036399841, 0.0054876371286809444, -0.13291053473949432, 
-0.1383194625377655, -0.060186877846717834, 0.040753982961177826, 0.025832200422883034, 0.06087275967001915, 0.07576646655797958, -0.025103572756052017, 0.0819762796163559, 0.06338494271039963, 0.09223338961601257, 0.11740309000015259, 0.16588829457759857, 0.0016070181736722589, -0.11642675846815109, 0.06580012291669846, 0.07179497182369232, -0.11596480011940002, 0.05284847319126129, 0.018308958038687706, 0.2823641896247864, 0.0026317911688238382, -0.013333271257579327, -0.07727757096290588, -0.06593139469623566, 0.06467396765947342, 0.04348631948232651, 0.02083323895931244, -0.004868550691753626, -0.06408777832984924, -0.12004149705171585, 0.09156100451946259, 0.04209277778863907, 0.04682828485965729, 0.06600149720907211, 0.014075364917516708, 0.02114858292043209]], dtype='float32')
query_id = "龙血王手串价格及图片"
k = 10
# index.search needs a 2-D (num_queries, index.d) array. Fail with a clear
# message instead of a low-level faiss assertion.
if query_embedding.ndim != 2 or query_embedding.shape[1] != index.d:
    raise ValueError(
        f"query_embedding must have shape (n, {index.d}), got {query_embedding.shape}"
    )
# Time only the search itself, not the result printing.
s = time.perf_counter()
distances, indices = index.search(query_embedding, k)
e = time.perf_counter()
print(f"Query ID: {query_id}")
print("Top k results:")
for rank, (idx, distance) in enumerate(zip(indices[0], distances[0]), start=1):
    if idx == -1:
        # -1 marks an empty slot when fewer than k neighbours were found.
        print(f" {rank}. No result")
        continue
    identifier = index_to_identifier.get(f"faiss_v1_{idx}", "Unknown")
    print(f" {rank}. ID: {identifier}, Distance: {distance}")
print(f"Time taken for search: {e - s} seconds")