import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
/**
  * DBSCAN density-based clustering of 2-D points, run as a standalone program.
  *
  * The program reads `x|y` records from a text file through Spark, collects them to the
  * driver and clusters them locally. A point is a ''core object'' when at least `MinPts`
  * samples (the point itself included) lie within distance `area` of it. Clusters are grown
  * breadth-first from randomly chosen core objects.
  *
  * Points are compared by value throughout (see [[reflash]]), so samples with identical
  * coordinates are treated as the same element when they are removed from a working set.
  *
  * @author XiaoTangBao
  * @date 2019/4/10 11:34
  * @version 1.0
  */
object DBSCAN {

  /**
    * Entry point: loads the samples, determines core / border points and prints every cluster.
    *
    * @param args command-line arguments (unused; the input path and parameters are hard-coded)
    */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val spark = SparkSession.builder().master("local[4]").appName("DBSCAN").getOrCreate()
    // Spark is only needed for loading; release the session as soon as the data is on the driver.
    val data =
      try {
        spark.sparkContext
          .textFile("G:\\mldata\\watermelon 4.0.txt")
          .filter(_.trim.nonEmpty) // skip blank lines instead of failing on arr(1)
          .map(_.split('|'))
          .map(arr => (arr(0).toDouble, arr(1).toDouble))
          .collect()
      } finally spark.stop()

    // Neighbourhood radius.
    val area = 0.11
    // Minimum number of samples (the point itself included) inside the neighbourhood
    // for a point to be a core object.
    val minPts = 5

    // Split the samples into core objects and the rest. The neighbourhood test is
    // inclusive (<=), the same test used for every other distance check in this object.
    val (corePoints, nonCorePoints) =
      data.partition(p => data.count(q => distance(p, q) <= area) >= minPts)
    var omg: ArrayBuffer[(Double, Double)] = ArrayBuffer.empty[(Double, Double)] ++= corePoints

    // A non-core point with no core object inside its neighbourhood is noise and can never
    // join a cluster; only the remaining (border) points take part in clustering.
    // `exists` is false on an empty core set, so every non-core point is then noise.
    var nonCore_Data: ArrayBuffer[(Double, Double)] =
      ArrayBuffer.empty[(Double, Double)] ++=
        nonCorePoints.filter(p => omg.exists(core => distance(p, core) <= area))

    // Final clustering result: one buffer of points per cluster.
    val result = ArrayBuffer[ArrayBuffer[(Double, Double)]]()

    // Grow one cluster per iteration until every core object has been assigned.
    while (omg.nonEmpty) {
      println(s"omg.length:${omg.length}")
      // Uniformly random index into the remaining core objects.
      val k = scala.util.Random.nextInt(omg.length)
      println(s"生成的随机数为:${k}")
      val sed = omg(k)

      // Points belonging to the cluster grown from this seed.
      val sed_data = ArrayBuffer(sed)
      omg = reflash(omg, ArrayBuffer(sed))

      // Breadth-first expansion: every newly reached core object is expanded in turn.
      val quen = mutable.Queue(sed)
      while (quen.nonEmpty) {
        val new_sed = quen.dequeue()
        val datas = find(new_sed, nonCore_Data, omg, area)
        val direct_nonData = datas(0)
        val direct_CoreData = datas(1)
        sed_data ++= direct_nonData
        sed_data ++= direct_CoreData
        // Remove the absorbed points from the working sets so they are assigned only once.
        omg = reflash(omg, direct_CoreData)
        nonCore_Data = reflash(nonCore_Data, direct_nonData)
        quen ++= direct_CoreData
      }
      result += sed_data
    }

    for ((cluster, k) <- result.zipWithIndex) {
      println(s"簇${k}的聚类结果")
      cluster.foreach(println(_))
    }
  }

  /**
    * Finds every point that is directly density-reachable from `sed`, i.e. lies within
    * distance `area` (inclusive) of it.
    *
    * @param sed          the core object being expanded
    * @param nonCore_Data candidate non-core points
    * @param omg          candidate core objects; an element equal to `sed` is never returned
    * @param area         neighbourhood radius
    * @return a buffer of exactly two buffers: index 0 holds the reachable non-core points,
    *         index 1 the reachable core objects, each in the order of its input buffer
    */
  def find(sed: (Double, Double), nonCore_Data: ArrayBuffer[(Double, Double)], omg: ArrayBuffer[(Double, Double)], area: Double): ArrayBuffer[ArrayBuffer[(Double, Double)]] = {
    val direct_nonData = nonCore_Data.filter(non => distance(non, sed) <= area)
    // Exclude only the seed itself. Comparing the coordinates separately would also drop
    // core objects that merely share an x or a y value with the seed.
    val direct_CoreData = omg.filter(core => core != sed && distance(core, sed) <= area)
    ArrayBuffer(direct_nonData, direct_CoreData)
  }

  /**
    * Returns a new buffer holding the elements of `a` that do not occur in `b`.
    * Elements are compared by value; neither argument is modified.
    *
    * @param a buffer to filter
    * @param b elements to remove from `a`
    * @return the elements of `a`, in order, that are not contained in `b`
    */
  def reflash(a: ArrayBuffer[(Double, Double)], b: ArrayBuffer[(Double, Double)]): ArrayBuffer[(Double, Double)] =
    a.filterNot(ai => b.contains(ai))

  /** Euclidean distance between two 2-D points. */
  private def distance(p: (Double, Double), q: (Double, Double)): Double =
    math.sqrt(math.pow(p._1 - q._1, 2) + math.pow(p._2 - q._2, 2))
}
// 原始数据集 (original data set):
// DBSCAN聚类后 (result after DBSCAN clustering):