最近着手的一个项目需要在Spark环境下使用DBSCAN算法,遗憾的是Spark MLlib中并没有提供该算法。调研了一些相关的文章,有些方案是将样本点按照空间位置进行分区,并在每个空间分区中分别跑DBSCAN,但是这种方案容易遇到数据倾斜的问题,并且在分区的边界的结果很有可能是错误的。
经过与一些小伙伴的交流,通过几天的探索尝试,最终在Spark上手工实现了分布式的DBSCAN算法,
经过校验结果和Sklearn单机结果完全一致,并且性能也达到了工业级水平
。
通过该算法的实现,加深了对Spark的理解,用到了分批次广播和分区迭代计算等技巧,感觉自己还是棒棒哒,特意分享出来供有需要的小伙伴们参考。
一,总体思路
DBSCAN算法的分布式实现需要解决以下一些主要的问题。
1,如何计算样本点中两两之间的距离?
在单机环境下,计算样本点两两之间的距离比较简单,是一个双重遍历的过程。
为了减少计算量,可以用空间索引如Rtree进行加速。
在分布式环境,样本点分布在不同的分区,难以在不同的分区之间直接进行双重遍历。
为了解决这个问题,我的方案是将样本点不同的分区分成多个批次拉到Driver端, 然后依次广播到各个excutor分别计算距离,将最终结果union,从而间接实现双重遍历。
为了减少计算量,广播前对拉到Driver端的数据构建空间索引Rtree进行加速。
2,如何构造临时聚类簇?
这个问题不难,单机环境和分布式环境的实现差不多。
都是通过group的方式统计每个样本点周边邻域半径R内的样本点数量,
并记录它们的id,如果这些样本点数量超过minpoints则构造临时聚类簇,并维护核心点列表。
3,如何合并相连的临时聚类簇得到聚类簇?
这个是分布式实现中最最核心的问题。
在单机环境下,标准做法是对每一个临时聚类簇,判断其中的样本点是否在核心点列表,如果是,则将该样本点所在的临时聚类簇与当前临时聚类簇合并。
并在核心点列表中删除该样本点。
重复此过程,直到当前临时聚类簇中所有的点都不在核心点列表。
在分布式环境下,临时聚类簇分布在不同的分区,无法直接扫描全局核心点列表进行临时聚类簇的合并。
我的方案是先在每一个分区内部对各个临时聚类簇进行合并,然后缩小分区数量重新分区,再在各个分区内部对每个临时聚类簇进行合并。
不断重复这个过程,最终将所有的临时聚类簇都划分到一个分区,完成对全部临时聚类簇的合并。
为了降低最后一个分区的存储压力,我采用了不同于标准的临时聚类簇的合并算法。
对每个临时聚类簇只关注其中的核心点id,而不关注非核心点id,以减少存储压力。
合并时将有共同核心点id的临时聚类簇合并。
为了加快临时聚类的合并过程,分区时并非随机分区,而是以每个临时聚类簇的核心点id中的最小值min_core_id作为分区的Hash参数,具有共同核心点id的临时聚类簇有更大的概率被划分到同一个分区,从而加快了合并过程。
二,核心代码
import org.apache.spark.sql.SparkSession
val spark = SparkSession
.builder()
.appName("dbscan")
.getOrCreate()
val sc = spark.sparkContext
import spark.implicits._
1,寻找核心点形成临时聚类簇。
该步骤一般要采用空间索引 + 广播的方法,此处从略,假定已经得到了临时聚类簇。
//rdd_core的每一行代表一个临时聚类簇:(min_core_id, core_id_set)
//core_id_set为临时聚类簇所有核心点的编号,min_core_id为这些编号中取值最小的编号
var rdd_core = sc.parallelize(List((1L,Set(1L,2L)),(2L,Set(2L,3L,4L)),
(6L,Set(6L,8L,9L)),(4L,Set(4L,5L)),
(9L,Set(9L,10L,11L)),(15L,Set(15L,17L)),
(10L,Set(10L,11L,18L))))
rdd_core.collect.foreach(println)
2,合并临时聚类簇得到聚类簇。
import scala.collection.mutable.ListBuffer
import org.apache.spark.HashPartitioner
//定义合并函数:将有共同核心点的临时聚类簇合并
val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{
var result = ListBuffer[Set[Long]]()
while (set_list.size>0){
var cur_set = set_list.remove(0)
var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)
while(intersect_idxs.size>0){
for(idx<-intersect_idxs){
cur_set = cur_set|set_list(idx)
}
for(idx<-intersect_idxs){
set_list.remove(idx)
}
intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)
}
result = result:+cur_set
}
result
}
///对rdd_core分区后在每个分区合并,不断将分区数量减少,最终合并到一个分区
//如果数据规模十分大,难以合并到一个分区,也可以最终合并到多个分区,得到近似结果。
//rdd: (min_core_id,core_id_set)
def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):
org.apache.spark.rdd.RDD[(Long,Set[Long])] = {
val rdd_merged = rdd.partitionBy(new HashPartitioner(partition_cnt))
.mapPartitions(iter => {
val buffer = ListBuffer[Set[Long]]()
for(t<-iter){
val core_id_set:Set[Long] = t._2
buffer.append(core_id_set)
}
val merged_buffer = mergeSets(buffer)
var result = List[(Long,Set[Long])]()
for(core_id_set<-merged_buffer){
val min_core_id = core_id_set.min
result = result:+(min_core_id,core_id_set)
}
result.iterator
})
rdd_merged
}
//分区迭代计算,可以根据需要调整迭代次数和分区数量
rdd_core = mergeRDD(rdd_core,8)
rdd_core = mergeRDD(rdd_core,4)
rdd_core = mergeRDD(rdd_core,1)
rdd_core.collect.foreach(println)
三,完整范例
完整范例还包括临时聚类簇的生成,以及最终聚类信息的整理。鉴于该部分代码较为冗长,在当前文章中不展示全部代码,仅说明最终结果。