Spark MLlib KMeans聚类算法

本文详细介绍了Spark MLlib中的KMeans聚类算法,包括基础理论、过程演示和源码分析。KMeans算法通过迭代寻找最优聚类中心,Spark Mllib实现了高效的KMeans算法,提供设置聚类个数、最大迭代次数等参数的功能。
摘要由CSDN通过智能技术生成

1.1 KMeans聚类算法

1.1.1 基础理论

KMeans算法的基本思想是初始随机给定K个簇中心,按照最邻近原则把待分类样本点分到各个簇。然后按平均法重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心的移动距离小于某个给定的值。

K-Means聚类算法主要分为三个步骤:

(1)第一步是为待聚类的点寻找聚类中心;

(2)第二步是计算每个点到聚类中心的距离,将每个点聚类到离该点最近的聚类中去;

(3)第三步是计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心;

反复执行(2)、(3),直到聚类中心不再进行大范围移动或者聚类次数达到要求为止。

1.1.2过程演示

下图展示了对n个样本点进行K-means聚类的效果,这里k取2:

(a)未聚类的初始点集;

(b)随机选取两个点作为聚类中心;

(c)计算每个点到聚类中心的距离,并聚类到离该点最近的聚类中去;

(d)计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心;

(e)重复(c),计算每个点到聚类中心的距离,并聚类到离该点最近的聚类中去;

(f)重复(d),计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心。

参照以下文档:

http://blog.sina.com.cn/s/blog_62186b46010145ne.html

1.2 Spark Mllib KMeans源码分析

class KMeansprivate (

    privatevar k: Int,

    privatevar maxIterations: Int,

    privatevar runs: Int,

    privatevar initializationMode: String,

    privatevar initializationSteps: Int,

    privatevar epsilon: Double,

    privatevar seed: Long)extends Serializablewith Logging {

// KMeans类参数:

k:聚类个数,默认2maxIterations:迭代次数,默认20runs:并行度,默认1

initializationMode:初始中心算法,默认"k-means||"initializationSteps:初始步长,默认5epsilon:中心距离阈值,默认1e-4seed:随机种子。

  /**

   * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,

   * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.

   */

  defthis() =this(2,20, 1, KMeans.K_MEANS_PARALLEL,5, 1e-4, Utils.random.nextLong())

// 参数设置

/** Set the number of clusters to create (k). Default: 2. */

  def setK(k: Int):this.type = {

    this.k = k

    this

  }

**省略各个参数设置代码**

// run方法,KMeans主入口函数

  /**

   * Train a K-means model on the given set of points; `data` should be cached for high

   * performance, because this is an iterative algorithm.

   */

  def run(data: RDD[Vector]): KMeansModel = {

 

    if (data.getStorageLevel == StorageLevel.NONE) {

      logWarning("The input data is not directly cached, which may hurt performance if its"

        + " parent RDDs are also uncached.")

    }

 

// Compute squared norms and cache them.

// 计算每行数据的L2范数,数据转换:data[Vector]=> data[(Vector, norms)],其中norms是Vector的L2范数,norms就是

    val norms = data.map(Vectors.norm(_,2.0))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值