ML----KNN算法----Spark实现

KNN算法思想

KNN(k-NearestNeighbor)又被称为最近邻算法,它的核心思想是:物以类聚,人以群分。KNN算法是机器学习 中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来 代表。KNN是一种分类算法,KNN没有显式的学习过程,也就是说没有训练阶段,待收到新样本后直接进行处理。
KNN的思路是:如果一个样本在特征空间中的k个最邻近的样本中的大多数属于某一个类别,则该样本也划分为这 个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个 样本的类别来决定待分样本所属的类别。
提到KNN,网上最常见的就是下面这个图,可以帮助大家理解。我们要确定绿点属于哪个颜色(红色或者蓝色), 要做的就是选出距离目标点距离最近的k个点,看这k个点的大多数颜色是什么颜色。当k取3的时候,我们可以看出 距离最近的三个,分别是红色、红色、蓝色,因此判定目标点为红色。
KNN

KNN算法描述

1)分别读取测试数据、训练数据集;
2)计算测试数据与训练数据之间的距离;
3)选取距离最小的K个点;
4) 确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类

Spark实现

/**
  * 数据集:鸢尾花数据集
  *
  * Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。
  * 数据集内包含 3 类共 150 条 记录,每类各 50 个数据,
  * 每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,
  * 可以通过这4个 特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。
  *
  * 数据:
  * 6.2,3.4,5.4,2.3,virginica
  * 5.9,3,5.1,1.8,virginica 
  * 5.8,3.4,2.6,2.2
  * 5.5,2.3,3.3,1.9  
  * 5.3,3.7,1.5,0.2,setosa
  * 5,3.3,1.4,0.2,setosa
  * 7,3.2,4.7,1.4,versicolor
  * .........等
  */

KNN算法实现

package algorithm.MachineLearning
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object SimpleKNN {

  def main(args: Array[String]): Unit = {
  
    //1.初始化
    val conf=new SparkConf().setAppName("SimpleKnn").setMaster("local[*]")
    val sc=new SparkContext(conf)
    val K=15

    //2.读取数据,封装数据
    val data: RDD[LabelPoint] = sc.textFile("file:///H:\\IDEA2019_WorkSpace\\SparkLearning\\src\\main\\data\\iris.csv")
      .map(line => {
        val arr = line.split(",")
        if (arr.length == 5) {
          LabelPoint(arr.last, arr.init.map(_.toDouble))
        } else {
          LabelPoint(" ", arr.map(_.toDouble))
        }
      })

      //3.过滤出样本数据和测试数据
      val sampleData=data.filter(_.label!=" ")
    val testData=data.filter(_.label==" ").map(_.point).collect()

    //4.求每一条测试数据与样本数据的距离
    testData.foreach(elem=>{
      val distance=sampleData.map(x=>(getDistance(elem,x.point),x.label))
      //获取距离最近的k个样本
      val minDistance=distance.sortBy(_._1).take(K)
      //取出这k个样本的label并且获取出现最多的label即为测试数据的label
      val labels=minDistance.map(_._2)
          .groupBy(x=>x)
          .mapValues(_.length)
          .toList
          .sortBy(_._2).reverse
          .take(1)
          .map(_._1)
      printf(s"${elem.toBuffer.mkString(",")},${labels.toBuffer.mkString(",")}")
      println()
    })
    sc.stop()

  }

  case class LabelPoint(label:String,point:Array[Double])

  import scala.math._

  def getDistance(x:Array[Double],y:Array[Double]):Double={
    sqrt(x.zip(y).map(z=>pow(z._1-z._2,2)).sum)
  }


}

KNN算法优化

package algorithm.MachineLearning

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import scala.collection.immutable.TreeSet

object SuperKNN {

  def main(args: Array[String]): Unit = {

    //1.初始化
    val conf=new SparkConf().setAppName("SimpleKnn").setMaster("local[*]")
    val sc=new SparkContext(conf)
    val K=15

    //2.读取数据,封装数据
    val data: RDD[LabelPoint] = sc.textFile("file:///H:\\IDEA2019_WorkSpace\\SparkLearning\\src\\main\\data\\iris.csv")
      .map(line => {
        val arr = line.split(",")
        if (arr.length == 5) {
          LabelPoint(arr.last, arr.init.map(_.toDouble))
        } else {
          LabelPoint(" ", arr.map(_.toDouble))
        }
      })

    //3.过滤出样本数据和测试数据
    val sampleData=data.filter(_.label!=" ")
    val testData=data.filter(_.label==" ").map(_.point).collect()

    //4.将testData封装到广播变量做一个优化
    val bc_testData=sc.broadcast(testData)

    //5.求每一条测试数据与样本数据的距离----使用mapPartitions应对大量数据集进行优化
    val distance: RDD[(String,(Double,String))] = sampleData.mapPartitions(iter => {
      val bc_points = bc_testData.value
      iter.flatMap(x => bc_points.map(point2 => (point2.mkString(","), (getDistance(point2, x.point),x.label))))
    })

    //6.求距离最小的k个点,使用aggregateByKey---先分局内聚合,再全局聚合
    distance.aggregateByKey(TreeSet[(Double,String)]())(
      (splitSet:TreeSet[(Double,String)],elem:(Double,String))=>{
        val newSet=splitSet+elem   //TreeSet默认是有序的(升序)
        newSet.take(K)
      },	
	(splitSet1:TreeSet[(Double,String)],splitSet2:TreeSet[(Double,String)])=>{
        (splitSet1 ++ splitSet2).take(K)
      }
    )
      //7.取出距离最小的k个点中出现次数最多的label---即为样本数据的label
      .map(x=>{
      (
        x._1,
        x._2.toArray.map(_._2).groupBy(y=>y).map(z=>(z._1,z._2.length)).toList.sortBy(_._2).map(_._1).take(1).mkString(",")
      )
    }).foreach(x=> println(x))

    sc.stop()
  }


  case class LabelPoint(label:String,point:Array[Double])

  import scala.math._

  def getDistance(x:Array[Double],y:Array[Double]):Double={
    sqrt(x.zip(y).map(z=>pow(z._1-z._2,2)).sum)
  }

}

KNN算法优缺点

  • 优点

(1)理论成熟,思想简单,既可以用来做分类也可以用来做回归;
(2)可用于非线性分类;
(3)训练时间复杂 度低,为O(n);
(4) 和朴素贝叶斯之类的算法比,对数据没有假设,准确度高,对异常点不敏感;

  • 缺点

(1)计算量大,尤其是特征数非常多的时候;
(2)样本不平衡的时候,对稀有类别的预测准 确率低;
(3)kd树,球树之类的模型建立需要大量的内存;
(4)使用懒散学习方法,基本上不学习,导致预测 时速度比起逻辑回归之类的算法慢;
(5)KNN模型可解释性不强。

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值