CH3_K近邻(KNN)算法及其Spark实现

KNN简介

KNN(k-Nearest Neighbors)又称作k-近邻。k-nn就是把未标记分类的案列归为与它们最相似的带有分类标记的案例所在的类。

KNN的特点

优点缺点
简单且有效不产生模型
训练阶段很快分类过程比较慢
对数据分布无要求模型解释性较差
适合稀疏时间和多分类问题名义变量和缺失数据需要额外处理

KNN模型

K近邻模型有三个基本要素:距离度量、K值的选择、分类决策规则

实现步骤
  1. 计算距离:计算待测案例与训练样本之间的距离 。
  2. 选择一个合适的k:确定用于KNN算法的邻居数量,一般用交叉验证或仅凭经验选择一个合适的k值,待测案例与训练样本之间距离最小的k个样本组成一个案例池。
  3. 类别判定:根据案例池的数据采用投票法或者加权投票法等方法来决定待测案例所属的类别。
KD-Tree

kd-tree是一种分割k维数据空间的数据结构。主要应用于多维空间数据的搜索,经常使用在SIFT、KNN等多维数据搜索的场景中,以KNN(K近邻)为例,使用线性搜索的方式效率低下,k-d树本质是对多维空间的划分,其每个节点都为k维点的二叉树kd-tree,因此可以大大提高搜索效率。详细的构造方法和kd树的最近邻搜索方法可以参考李航老师的《统计学习方法》。

1.定义Kd树类及其方法
package CH3_KNearestNeibor

/**
  * Created by WZZC on 2019/11/29
  **/
/**
  *
  * @param label 分类指标
  *  @param value 节点数据
  *  @param dim   当前切分维度
  *  @param left  左子节点
  *  @param right 右子节点
  */
case class TreeNode(label: String,
                    value: Seq[Double],
                    dim: Int,
                    var left: TreeNode,
                    var right: TreeNode)
    extends Serializable {}

object TreeNode {
  import statisticslearn.DataUtils.distanceUtils._

  /**
    *创建KD 树
    *
    * @param value
    * @param dim
    * @param shape
    * @return
    */
  def creatKdTree(value: Seq[(String, Seq[Double])],
                  dim: Int,
                  shape: Int): TreeNode = {

    // 数据按照当前划分的维度排序
    val sorted: Seq[(String, Seq[Double])] = value.sortBy(tp2 => tp2._2(dim))

    //中间位置的索引
    val midIndex: Int = value.length / 2

    sorted match {
      // 当节点为空时,返回null
      case Nil => null

      //节点不为空时,递归调用方法
      case _ =>
        val left: Seq[(String, Seq[Double])] = sorted.slice(0, midIndex)
        val right: Seq[(String, Seq[Double])] =
          sorted.slice(midIndex + 1, value.length)

        val leftNode = creatKdTree(left, (dim + 1) % shape, shape) //左子节点递归创建树
        val rightNode = creatKdTree(right, (dim + 1) % shape, shape) //右子节点递归创建树

        TreeNode(
          sorted(midIndex)._1,
          sorted(midIndex)._2,
          dim,
          leftNode,
          rightNode
        )

    }
  }

  /**
    * 从root节点开始,DFS搜索直到叶子节点,同时在stack中顺序存储已经访问的节点。
    * 如果搜索到叶子节点,当前的叶子节点被设为最近邻节点。
    * 然后通过stack回溯:
    * 如果当前点的距离比最近邻点距离近,更新最近邻节点.
    * 然后检查以最近距离为半径的圆是否和父节点的超平面相交.
    * 如果相交,则必须到父节点的另外一侧,用同样的DFS搜索法,开始检查最近邻节点。
    * 如果不相交,则继续往上回溯,而父节点的另一侧子节点都被淘汰,不再考虑的范围中.
    * 当搜索回到root节点时,搜索完成,得到最近邻节点。
    *
    * @param treeNode
    * @param data
    * @param k
    * @return
    */
  def knn(treeNode: TreeNode, data: Seq[Double], k: Int = 1) = {

    //    implicit def vec2Seq(a:DenseVector[Double])=a.toArray.toSeq

    var resArr = new Array[(Double, TreeNode)](k)
      .map(_ => (Double.MaxValue, null))
      .asInstanceOf[Array[(Double, TreeNode)]]

    def finder(treeNode: TreeNode): TreeNode = {

      if (treeNode != null) {
        val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)
        if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)

        val distc: Double = euclidean(treeNode.value, data)

        if (distc < resArr.last._1) {
          resArr.update(k - 1, (distc, treeNode))
          resArr = resArr.sortBy(_._1)
        }

        if (math.abs(dimr) < resArr.last._1)
          if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)

      }
      resArr.last._2
    }

    finder(treeNode)
    resArr

  }

}

2.Spark实现 Knn模型
package CH3_KNearestNeibor

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._


/**
  * Created by WZZC on 2019/11/29
  **/
case class KnnModel(data: DataFrame, labelName: String) extends Serializable {

  private val spark = data.sparkSession

//  import spark.implicits._
  // 使用.rdd的时候不能使用 col
//  private val sfadsfaggaggsagafasavsa: String = UUID.randomUUID().toString

  private val ftsName: String = Identifiable.randomUID("KnnModel")

  // 数据特征名称
  private val fts: Array[String] = data.columns.filterNot(_ == labelName)

  val shapes: Int = fts.length

  def vec2Seq = udf((vec: DenseVector) => vec.toArray.toSeq)

  /**
    *
    * @param dataFrame
    * @return
    */
  def dataTransForm(dataFrame: DataFrame) = {
    new VectorAssembler()
      .setInputCols(fts)
      .setOutputCol(ftsName)
      .transform(dataFrame)
  }

  private val kdtrees: Array[TreeNode] = dataTransForm(data)
    .withColumn(ftsName, vec2Seq(col(ftsName)))
    .select(labelName, ftsName)
    .withColumn("partitionIn", spark_partition_id())
    .rdd //在大数据情况下,分区构建kdtree
    .map(row => {
      val partitionIn = row.getInt(2)
      val label = row.getString(0)
      val features = row.getAs[Seq[Double]](1)
      (partitionIn, label, features)
    })
    .groupBy(_._1)
    .mapValues(_.toSeq.map(tp3 => (tp3._2, tp3._3)))
    .mapValues(nn => TreeNode.creatKdTree(nn, 0, shapes))
    .values
    .collect()


  /**
    *
    * @param predictDf
    * @param k
    * @return
    */
  def predict(predictDf: DataFrame, k: Int): DataFrame = {

    // 此处方法重载需要注意:overloaded method needs result type
    def nsearchUdf = udf((seq: Seq[Double]) => predict(seq, k))

    dataTransForm(predictDf)
      .withColumn(ftsName, vec2Seq(col(ftsName)))
      .withColumn(labelName, nsearchUdf(col(ftsName)))
      .drop(ftsName)

  }

  /**
    *
    * @param predictData
    * @param k
    * @return
    */
  def predict(predictData: Seq[Double], k: Int): String = {

    // 查询的时候遍历每个kdtree,然后取结果集再排序
    val res: Array[(Double, Seq[Double], String)] = kdtrees
      .map(node => {
        TreeNode.knn(node, predictData, k)
          .map(tp2 => (tp2._1, tp2._2.value, tp2._2.label))
      })
      .flatMap(_.toSeq)
      .sortBy(_._1)
      .take(k)

    // 按照投票选举的方法选择分类结果
    val cl = res
      .map(tp3 => (tp3._3, 1))
      .groupBy(_._1)
      .mapValues(_.map(_._2).sum)
      .maxBy(_._2)
      ._1
    cl
  }

 
}

3.算法测试
package CH3_KNearestNeibor

import org.apache.spark.sql.SparkSession

/**
  * Created by WZZC on 2019/11/29
  **/
object KNNRunner {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession
      .builder()
      .appName(s"${this.getClass.getSimpleName}")
      .master("local[*]")
      .getOrCreate()

    val iris = spark.read
      .option("inferSchema", true)
      .option("header", true)
      .csv("/data/iris.csv")

    val model: KnnModel = KnnModel(iris, "class")

    model.predict(iris, 3).show()

    spark.stop()

  }

}

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 19
    评论
以下是一个简单的 Java 实现 k 近邻算法: ```java import java.util.*; public class KNN { private List<Data> trainSet; public KNN(List<Data> trainSet) { this.trainSet = trainSet; } public String classify(Data testData, int k) { List<Data> nearestNeighbors = findNearestNeighbors(testData, k); String label = getMostFrequentLabel(nearestNeighbors); return label; } private List<Data> findNearestNeighbors(Data testData, int k) { PriorityQueue<Data> pq = new PriorityQueue<>(k, new Comparator<Data>() { @Override public int compare(Data d1, Data d2) { return Double.compare(d2.distance, d1.distance); } }); for (Data trainData : trainSet) { double distance = getDistance(testData, trainData); trainData.distance = distance; if (pq.size() < k) { pq.offer(trainData); } else { Data top = pq.peek(); if (distance < top.distance) { pq.poll(); pq.offer(trainData); } } } List<Data> result = new ArrayList<>(); while (!pq.isEmpty()) { result.add(pq.poll()); } return result; } private String getMostFrequentLabel(List<Data> nearestNeighbors) { Map<String, Integer> countMap = new HashMap<>(); for (Data data : nearestNeighbors) { String label = data.label; countMap.put(label, countMap.getOrDefault(label, 0) + 1); } String label = null; int maxCount = 0; for (Map.Entry<String, Integer> entry : countMap.entrySet()) { if (entry.getValue() > maxCount) { label = entry.getKey(); maxCount = entry.getValue(); } } return label; } private double getDistance(Data d1, Data d2) { double sum = 0.0; for (int i = 0; i < d1.features.length; i++) { sum += Math.pow(d1.features[i] - d2.features[i], 2); } return Math.sqrt(sum); } private static class Data { String label; double[] features; double distance; public Data(String label, double[] features) { this.label = label; this.features = features; } } } ``` 使用方法: ```java List<KNN.Data> trainSet = new ArrayList<>(); trainSet.add(new KNN.Data("A", new double[]{1, 2})); trainSet.add(new KNN.Data("A", new double[]{2, 3})); trainSet.add(new KNN.Data("B", new double[]{3, 4})); trainSet.add(new KNN.Data("B", new double[]{4, 5})); KNN knn = new KNN(trainSet); KNN.Data testData = new KNN.Data(null, new double[]{1.5, 2.5}); String label = knn.classify(testData, 3); System.out.println(label); // 输出 "A" ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值