In big-data scenarios, efficient Approximate Nearest Neighbors (ANN) search is key to many applications such as recommender systems and image retrieval. The traditional single-machine HNSWlib becomes slow on large-scale data, so we tried the distributed solution HNSWlib-PySpark for our recall experiments.
Background
HNSW (Hierarchical Navigable Small World) is an efficient ANN algorithm that accelerates search by building a hierarchical graph structure. HNSWlib is its implementation library, but its performance is limited when processing large-scale data on a single machine. HNSWlib-PySpark integrates the HNSW algorithm with PySpark and uses distributed computing to handle massive datasets more efficiently.
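For comparison, the single-machine baseline with HNSWlib looks roughly like the sketch below; the dimensionality, element count, and random data are placeholders:
import numpy as np
import hnswlib

dim, num_elements = 64, 10000                       # placeholder sizes
data = np.random.rand(num_elements, dim).astype(np.float32)

# Build an inner-product index; M and ef_construction trade build time for recall.
index = hnswlib.Index(space='ip', dim=dim)
index.init_index(max_elements=num_elements, ef_construction=200, M=48)
index.add_items(data, np.arange(num_elements))

# ef controls the accuracy/speed trade-off at query time.
index.set_ef(50)
labels, distances = index.knn_query(data[:10], k=10)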
Testing HNSWlib-PySpark
Installation
First, make sure HNSWlib-PySpark is installed:
pip install pyspark-hnsw --upgrade
When submitting the PySpark job, add the following configuration so the hnswlib-spark JAR is available on the cluster:
--conf spark.jars.packages=com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0
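Alternatively, the same dependency can be attached programmatically when the SparkSession is created, as in the sketch below; note that spark.jars.packages is only honoured if no Spark context is already running.
from pyspark.sql import SparkSession

# Resolve the hnswlib-spark package from Maven when the session starts.
spark = SparkSession.builder \
    .config("spark.jars.packages", "com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0") \
    .getOrCreate()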
Test Code
The complete test code is shown below:
import os
import logging

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vectors
from pyspark_hnsw.knn import HnswSimilarity, BruteForceSimilarity
from pyspark_hnsw.evaluation import KnnSimilarityEvaluator
from pyspark_hnsw.linalg import Normalizer
from pyspark_hnsw.conversion import VectorConverter

hadoop = os.path.join(os.environ['HADOOP_COMMON_HOME'], 'bin/hadoop')
def init_spark():
    spark = SparkSession.builder \
        .config("spark.sql.caseSensitive", "false") \
        .config("spark.shuffle.spill", "true") \
        .config("spark.shuffle.spill.compress", "true") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("metastore.catalog.default", "hive") \
        .config("spark.sql.hive.convertMetastoreOrc", "true") \
        .config("spark.kryoserializer.buffer.max", "1024m") \
        .config("spark.kryoserializer.buffer", "64m") \
        .config("spark.driver.maxResultSize", "4g") \
        .config("spark.sql.broadcastTimeout", "36000") \
        .enableHiveSupport() \
        .getOrCreate()
    return spark

def system_command(command):
    code = os.system(command)
    if code != 0:
        logging.error(f"Command ({command}) failed to execute.")
    else:
        logging.info(f"Command ({command}) executed successfully.")

if __name__ == "__main__":
    spark = init_spark()
    # Load the user embeddings
    df = spark.sql(
        """select user_id_zm, user_embedding from algo.dssm_user_embedding where pt='2025-05-18' """
    )
    # Convert the embedding column into dense vectors
    df_user_id = df.rdd.map(lambda row: row.user_id_zm)
    df_embedding = df.rdd.map(lambda row: Vectors.dense(row.user_embedding))
    new_df = df_user_id.zip(df_embedding).toDF(schema=['user_id_zm', 'user_embedding'])
    # Inspect the data
    new_df.show()
    # Preprocessing: convert to ml vectors and normalize
    converter = VectorConverter(inputCol='user_embedding', outputCol='features')
    normalizer = Normalizer(inputCol='features', outputCol='normalized_features')
    # Approximate nearest-neighbour search with HNSW
    hnsw = HnswSimilarity(identifierCol='user_id_zm', queryIdentifierCol='user_id_zm',
                          featuresCol='normalized_features', distanceFunction='inner-product',
                          m=48, ef=15, k=10, efConstruction=200, numPartitions=2,
                          excludeSelf=True, similarityThreshold=0.4, predictionCol='approximate')
    # Exact (brute-force) search as the ground truth
    brute_force = BruteForceSimilarity(identifierCol='user_id_zm', queryIdentifierCol='user_id_zm',
                                       featuresCol='normalized_features', distanceFunction='inner-product',
                                       k=10, numPartitions=2, excludeSelf=True,
                                       similarityThreshold=0.4, predictionCol='exact')
    # Build the pipeline
    pipeline = Pipeline(stages=[converter, normalizer, hnsw, brute_force])
    model = pipeline.fit(new_df)
    # Query with a 1% sample of the data
    query_items = new_df.sample(0.01)
    output = model.transform(query_items)
    # Evaluate the approximate results against the exact ones
    evaluator = KnnSimilarityEvaluator(approximateNeighborsCol='approximate', exactNeighborsCol='exact')
    accuracy = evaluator.evaluate(output)
    print("accuracy: ", accuracy)
    # Stop the SparkSession
    spark.stop()
    del spark
Experimental Results
In our test, the recall of HNSWlib-PySpark against the brute-force results was between 0.8 and 0.9, which is an acceptable trade-off for large-scale data. The advantage of HNSWlib-PySpark lies in its distributed architecture, which lets it handle massive datasets effectively and speeds up the recall stage.
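Beyond the single aggregate number returned by KnnSimilarityEvaluator, per-query recall can be inspected directly on the output DataFrame. The sketch below assumes the 'approximate' and 'exact' prediction columns are arrays of structs with a neighbor field, as produced by this version of hnswlib-spark; the recall UDF is a hypothetical helper, not part of the library.
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType

# Hypothetical helper: fraction of exact neighbours that also show up
# in the approximate result for each query row.
@udf(DoubleType())
def recall(approximate, exact):
    if not exact:
        return None
    exact_ids = {row.neighbor for row in exact}
    approx_ids = {row.neighbor for row in approximate} if approximate else set()
    return len(exact_ids & approx_ids) / len(exact_ids)

per_query = output.withColumn('recall', recall(col('approximate'), col('exact')))
per_query.selectExpr('avg(recall) as mean_recall').show()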