Spark跑「DBSCAN」算法,工业级代码长啥样?

最近着手的一个项目需要在Spark环境下使用DBSCAN算法,遗憾的是Spark MLlib中并没有提供该算法。调研了一些相关的文章,有些方案是将样本点按照空间位置进行分区,并在每个空间分区中分别跑DBSCAN,但是这种方案容易遇到数据倾斜的问题,并且在分区的边界的结果很有可能是错误的。

经过与一些小伙伴的交流,通过几天的探索尝试,最终在Spark上手工实现了分布式的DBSCAN算法, 经过校验结果和Sklearn单机结果完全一致,并且性能也达到了工业级水平 。

通过该算法的实现,加深了对Spark的理解,用到了分批次广播和分区迭代计算等技巧,感觉自己还是棒棒哒,特意分享出来供有需要的小伙伴们参考。

一,总体思路

DBSCAN算法的分布式实现需要解决以下一些主要的问题。
640?wx_fmt=png
1,如何计算样本点中两两之间的距离?
在单机环境下,计算样本点两两之间的距离比较简单,是一个双重遍历的过程。 为了减少计算量,可以用空间索引如Rtree进行加速。

在分布式环境,样本点分布在不同的分区,难以在不同的分区之间直接进行双重遍历。 为了解决这个问题,我的方案是将样本点不同的分区分成多个批次拉到Driver端, 然后依次广播到各个excutor分别计算距离,将最终结果union,从而间接实现双重遍历。
为了减少计算量,广播前对拉到Driver端的数据构建空间索引Rtree进行加速。

2,如何构造临时聚类簇?
这个问题不难,单机环境和分布式环境的实现差不多。
都是通过group的方式统计每个样本点周边邻域半径R内的样本点数量,
并记录它们的id,如果这些样本点数量超过minpoints则构造临时聚类簇,并维护核心点列表。

3,如何合并相连的临时聚类簇得到聚类簇?
这个是分布式实现中最最核心的问题。
在单机环境下,标准做法是对每一个临时聚类簇,判断其中的样本点是否在核心点列表,如果是,则将该样本点所在的临时聚类簇与当前临时聚类簇合并。 并在核心点列表中删除该样本点。 重复此过程,直到当前临时聚类簇中所有的点都不在核心点列表。
在分布式环境下,临时聚类簇分布在不同的分区,无法直接扫描全局核心点列表进行临时聚类簇的合并。 我的方案是先在每一个分区内部对各个临时聚类簇进行合并,然后缩小分区数量重新分区,再在各个分区内部对每个临时聚类簇进行合并。 不断重复这个过程,最终将所有的临时聚类簇都划分到一个分区,完成对全部临时聚类簇的合并。
为了降低最后一个分区的存储压力,我采用了不同于标准的临时聚类簇的合并算法。 对每个临时聚类簇只关注其中的核心点id,而不关注非核心点id,以减少存储压力。 合并时将有共同核心点id的临时聚类簇合并。

为了加快临时聚类的合并过程,分区时并非随机分区,而是以每个临时聚类簇的核心点id中的最小值min_core_id作为分区的Hash参数,具有共同核心点id的临时聚类簇有更大的概率被划分到同一个分区,从而加快了合并过程。

二,核心代码

import org.apache.spark.sql.SparkSession	

	
val spark = SparkSession	
.builder()	
.appName("dbscan")	
.getOrCreate()	

	
val sc = spark.sparkContext	
import spark.implicits._

1,寻找核心点形成临时聚类簇。

该步骤一般要采用空间索引 + 广播的方法,此处从略,假定已经得到了临时聚类簇。

//rdd_core的每一行代表一个临时聚类簇:(min_core_id, core_id_set)	
//core_id_set为临时聚类簇所有核心点的编号,min_core_id为这些编号中取值最小的编号	
var rdd_core = sc.parallelize(List((1L,Set(1L,2L)),(2L,Set(2L,3L,4L)),	
                                       (6L,Set(6L,8L,9L)),(4L,Set(4L,5L)),	
                                       (9L,Set(9L,10L,11L)),(15L,Set(15L,17L)),	
                                       (10L,Set(10L,11L,18L))))	
rdd_core.collect.foreach(println)

640?wx_fmt=png

2,合并临时聚类簇得到聚类簇。

import scala.collection.mutable.ListBuffer	
import org.apache.spark.HashPartitioner	

	
//定义合并函数:将有共同核心点的临时聚类簇合并	
val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{	
var result = ListBuffer[Set[Long]]()	
while (set_list.size>0){	
var cur_set = set_list.remove(0)	
var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)	
while(intersect_idxs.size>0){	
for(idx<-intersect_idxs){	
        cur_set = cur_set|set_list(idx)	
      }	
for(idx<-intersect_idxs){	
        set_list.remove(idx)	
      }	
      intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)	
    }	
result = result:+cur_set	
  }	
result	
}	

	
///对rdd_core分区后在每个分区合并,不断将分区数量减少,最终合并到一个分区	
//如果数据规模十分大,难以合并到一个分区,也可以最终合并到多个分区,得到近似结果。	
//rdd: (min_core_id,core_id_set)	

	
def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):	
org.apache.spark.rdd.RDD[(Long,Set[Long])] = {	
  val rdd_merged =  rdd.partitionBy(new HashPartitioner(partition_cnt))	
    .mapPartitions(iter => {	
      val buffer = ListBuffer[Set[Long]]()	
for(t<-iter){	
        val core_id_set:Set[Long] = t._2	
        buffer.append(core_id_set)	
      }	
      val merged_buffer = mergeSets(buffer)	
var result = List[(Long,Set[Long])]()	
for(core_id_set<-merged_buffer){	
        val min_core_id = core_id_set.min	
result = result:+(min_core_id,core_id_set)	
      }	
      result.iterator	
    })	
  rdd_merged	
}

//分区迭代计算,可以根据需要调整迭代次数和分区数量	
rdd_core = mergeRDD(rdd_core,8)	
rdd_core = mergeRDD(rdd_core,4)	
rdd_core = mergeRDD(rdd_core,1)	
rdd_core.collect.foreach(println)

640?wx_fmt=png

三,完整范例

完整范例还包括临时聚类簇的生成,以及最终聚类信息的整理。鉴于该部分代码较为冗长,在当前文章中不展示全部代码,仅说明最终结果。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值