Three implementations of DBSCAN on Spark that I have looked at:
- https://github.com/alitouka/spark_dbscan written in Scala. The author also provides two small R tools which will help you choose parameters of the DBSCAN algorithm.
- https://github.com/irvingc/dbscan-on-spark also written in Scala; reportedly it uses a lot of memory.
- An Implementation of DBSCAN on PySpark written in PySpark; simple and easy to use, but it never finished once I switched to lat/lng coordinates. The author optimizes the search by partitioning the data around pivot points: if mε ≤ d(x, c) ≤ (m+1)ε for some pivot c, then by the triangle inequality every ε-neighbor y of x satisfies (m-1)ε ≤ d(y, c) ≤ (m+2)ε, so any point z with d(z, c) < (m-1)ε or d(z, c) > (m+2)ε can be filtered out. Partitioning on this rule avoids a lot of redundant work: points in distant partitions are too far apart to be worth checking. What about points on a partition boundary? The partitions are shifted by half an epsilon and overlaid on the original ones, so boundary points can still be assigned to the right cluster.
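The ring-filter argument in that last bullet can be sketched directly with NumPy. This is a minimal illustration under assumed 2-D Euclidean distances and random points, not the author's actual partitioner:

```python
import numpy as np

def candidate_neighbors(dist_to_pivot, m, eps):
    """Only points whose distance to the pivot lies in [(m-1)*eps, (m+2)*eps]
    can be eps-neighbors of a point in ring m (triangle inequality)."""
    lo, hi = (m - 1) * eps, (m + 2) * eps
    return (dist_to_pivot >= lo) & (dist_to_pivot <= hi)

rng = np.random.default_rng(0)
pts = rng.uniform(0, 10, size=(1000, 2))   # synthetic 2-D points
pivot = np.zeros(2)
eps = 0.5
d = np.linalg.norm(pts - pivot, axis=1)    # distance of each point to the pivot

m = 4  # ring m holds points x with m*eps <= d(x, pivot) <= (m+1)*eps
mask = candidate_neighbors(d, m, eps)

# Check the claim: every true eps-neighbor of a ring-m point
# falls inside the candidate band, so nothing outside it is ever needed.
ring = pts[(d >= m * eps) & (d <= (m + 1) * eps)]
for x in ring:
    nbrs = pts[np.linalg.norm(pts - x, axis=1) <= eps]
    nd = np.linalg.norm(nbrs - pivot, axis=1)
    assert np.all((nd >= (m - 1) * eps) & (nd <= (m + 2) * eps))
```

Only the band of width 3ε around ring m ever needs to be scanned, which is what lets the PySpark implementation skip whole partitions.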
In the end I used none of the three. Instead I wrapped sklearn's DBSCAN in a UDF, and the results came out quite well. PySpark is so concise! :-)
```python
# import findspark
# findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
import pandas as pd
import numpy as np
# import os
from sklearn.cluster import DBSCAN, KMeans


def dbscan_x(coords):
    kms_per_radian = 6371.0086  # mean Earth radius in km
    # eps = 200 m radius, at least 20 points per cluster
    epsilon = 0.2 / kms_per_radian
    db = DBSCAN(eps=epsilon, min_samples=20, algorithm='ball_tree',
                metric='haversine').fit(np.radians(coords))
    cluster_labels = db.labels_
    num_clusters = len(set(cluster_labels) - set([-1]))  # label -1 is noise
    result = []
    coords_np = np.array(coords)
    kmeans = KMeans(n_clusters=1, n_init=1, max_iter=10, random_state=7)
    for n in range(num_clusters):
        # get the center of cluster n via a 1-cluster KMeans (i.e. its mean point)
        one_cluster = coords_np[cluster_labels == n]
        kk = kmeans.fit(one_cluster)
        center = kk.cluster_centers_
        latlng = center[0].tolist()  # input points are [lat, lng]
        result.append([n, latlng[1], latlng[0]])  # -> [clusterid, lng, lat]
    return result


if __name__ == "__main__":
    spark = SparkSession.builder \
        .appName("stop_cluster") \
        .getOrCreate()
    # data_file = os.path.normpath('E:\\datas\\trj_data\\tmp_stop3_loc.csv')
    # traj_schema = StructType([
    #     StructField("cid", StringType()),
    #     StructField("lng", FloatType()), StructField("lat", FloatType())
    # ])
    # dataDF = spark.read.csv(data_file, schema=traj_schema)
    dataDF = spark.sql("select cid, lng, lat from tmp.tmp_some_data")
    schema_dbs = ArrayType(StructType([
        StructField("clusterid", IntegerType(), False),
        StructField("lng", FloatType(), False),
        StructField("lat", FloatType(), False)
    ]))
    udf_dbscan = F.udf(lambda x: dbscan_x(x), schema_dbs)
    dataDF = dataDF.withColumn('point', F.array(F.col('lat'), F.col('lng'))) \
        .groupBy('cid').agg(F.collect_list('point').alias('point_list')) \
        .withColumn('cluster', udf_dbscan(F.col('point_list')))
    resultDF = dataDF.withColumn('centers', F.explode('cluster')) \
        .select('cid', F.col('centers').getItem('lng').alias('lng'),
                F.col('centers').getItem('lat').alias('lat'),
                F.col('centers').getItem('clusterid').alias('clusterid'))
    resultDF.write.mode("overwrite").format("orc").saveAsTable("tmp.tmp_cluster_ret")
    resultDF.show()
    spark.stop()
```
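As a quick sanity check outside Spark, the same `dbscan_x` logic can be exercised locally on synthetic [lat, lng] pairs. Here is a sketch with two tight groups of 25 points each (made-up coordinates, not the real `tmp.tmp_some_data`), which should come out as exactly two clusters:

```python
import numpy as np
from sklearn.cluster import DBSCAN, KMeans

def dbscan_x(coords):
    kms_per_radian = 6371.0086
    epsilon = 0.2 / kms_per_radian  # 200 m expressed in radians
    db = DBSCAN(eps=epsilon, min_samples=20, algorithm='ball_tree',
                metric='haversine').fit(np.radians(coords))
    labels = db.labels_
    num_clusters = len(set(labels) - {-1})
    coords_np = np.array(coords)
    kmeans = KMeans(n_clusters=1, n_init=1, max_iter=10, random_state=7)
    result = []
    for n in range(num_clusters):
        center = kmeans.fit(coords_np[labels == n]).cluster_centers_[0]
        result.append([n, center[1], center[0]])  # [clusterid, lng, lat]
    return result

rng = np.random.default_rng(42)
# two groups of 25 points, each jittered ~20 m around its center
g1 = np.array([30.0, 120.0]) + rng.normal(0, 0.0002, size=(25, 2))
g2 = np.array([31.0, 121.0]) + rng.normal(0, 0.0002, size=(25, 2))
coords = np.vstack([g1, g2]).tolist()

clusters = dbscan_x(coords)
print(clusters)  # two rows of [clusterid, lng, lat]
```

With min_samples=20 and 25 points packed well inside the 200 m radius, each group forms one cluster; spreading the jitter wider than eps, or dropping a group below 20 points, turns it into noise (label -1) instead.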