Spark之K-近邻算法
关于K-近邻算法的详细描述,可以参考《MapReduce之KNN算法》一文。
简而言之,K-近邻算法即根据已经分类好的数据,通过特定的方式进行对比,对未分类的数据进行分类。Spark程序如下所示:
package KNN
import org.apache.spark.{SparkConf, SparkContext}
/**
 * K-nearest-neighbour classification on Spark.
 *
 * Every query record read from `input/R.txt` is paired with every training
 * record read from `input/S.txt`. For each query record the `k` closest
 * training records (Euclidean distance) vote on a classification, and the
 * resulting `(queryRecordId, classificationId)` pairs are printed.
 *
 * Record layout, as consumed by the parsing code in [[main]]:
 *  - R line: `recordId;v1,v2,...`
 *  - S line: `<unused first field>;classificationId;v1,v2,...`
 */
object KNN {

  /**
   * Euclidean distance between two comma-separated numeric vectors.
   *
   * @param rAsString query vector, e.g. `"1.0,2.0"`
   * @param sAsString training vector in the same format
   * @param d         expected number of components in both vectors
   * @return the distance, or `Double.NaN` when either vector does not have
   *         exactly `d` components
   * @throws NumberFormatException if a component is not a valid double
   */
  private def calculateDistance(rAsString: String, sAsString: String, d: Int): Double = {
    val r = rAsString.split(",").map(_.toDouble)
    val s = sAsString.split(",").map(_.toDouble)
    if (r.length != d || s.length != d) Double.NaN
    else math.sqrt(r.zip(s).map { case (ri, si) => (ri - si) * (ri - si) }.sum)
  }

  /**
   * Picks the classification that occurs most often among the neighbours.
   *
   * @param nearest `(distance, classificationId)` pairs, ordered nearest first
   * @return the most frequent classification; when several classifications are
   *         tied, the one whose neighbour appears earliest in `nearest` wins
   * @throws UnsupportedOperationException if `nearest` is empty
   */
  private def majorityClass(nearest: Seq[(Double, String)]): String = {
    val votes = nearest.groupBy(_._2).map { case (label, hits) => label -> hits.size }
    // Iterating the neighbours (not the vote map) makes tie-breaking deterministic:
    // maxBy returns the first element that reaches the largest vote count.
    nearest.map(_._2).maxBy(votes)
  }

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("KNN").setMaster("local")
    val sc = new SparkContext(sparkConf)
    try {
      val k = 2 // number of neighbours that vote (the "k" in KNN)
      val d = 2 // dimension of every vector
      val inputDatasetR = "input/R.txt" // query data set
      val inputDatasetS = "input/S.txt" // training data set

      // Share the parameters with the tasks as broadcast variables.
      val broadcastK = sc.broadcast(k)
      val broadcastD = sc.broadcast(d)

      // Blank lines carry no record and would otherwise fail the field lookups below.
      val R = sc.textFile(inputDatasetR).filter(_.trim.nonEmpty)
      val S = sc.textFile(inputDatasetS).filter(_.trim.nonEmpty)

      // Cartesian product of R and S: one distance per (query, training) pair,
      // keyed by the query record id.
      val knnMapped = R.cartesian(S).map { case (rRecord, sRecord) =>
        val rTokens = rRecord.split(";")
        val sTokens = sRecord.split(";")
        val distance = calculateDistance(rTokens(1), sTokens(2), broadcastD.value)
        (rTokens(0), (distance, sTokens(1)))
      }

      // Collect all distances of a query record, keep the k smallest, and vote.
      // NOTE(review): a NaN distance (dimension mismatch) still takes part in the
      // sort; confirm whether such pairs should be dropped instead.
      val knnOutput = knnMapped.groupByKey().mapValues { neighbours =>
        majorityClass(neighbours.toList.sortBy(_._1).take(broadcastK.value))
      }

      // Print one (queryRecordId, classificationId) pair per query record.
      knnOutput.foreach(println)
    } finally {
      // Release the SparkContext even when the job fails.
      sc.stop()
    }
  }
}