Streaming做KMeans、实时KMeans算法

Streaming是怎么做KMeans的?

大家好,我是一拳就能打爆你A柱的男人

大家在学机器学习的时候一定看过K-Means算法,但是各位有没有想过在实时计算的时候是如何做K-Means的呢?接下来我打算从下面几个方面来给大家梳理一下:1、K-Means算法原理,2、Streaming K-Means手算,3、Streaming K-Means源码解读。

关键词:StreamingKmeans,实时KMeans,源码解读StreamingKmeans

1、 K-Means算法原理

关于K-Means算法我之前有一篇博客也讲过,并且附带案例。各位有兴趣的可以去看一看:K-Means算法及相关案例 。接下来我还是简单讲一下吧。

首先请看下面这张图:

在这里插入图片描述

​ 图1来源:KMeans 算法(一)

  • a:显然有两个簇的样本,但是此时并没有做出聚类。
  • b:随机生成了红蓝两个点(即簇心)
  • c:此时开始迭代(c~f步)每个样本都会分别计算样本自身到两个簇心的距离,并将自身归类为最近簇心一类。
    • d:在c中样本分成红蓝两色,但是显然没有分类完全,所以需要更新簇心位置,簇心位置由该类中所有样本的位置(向量)决定。
    • e:经过d的簇心位置计算,样本已经很靠近自身簇的样本点,此时再次更新簇心。
    • f:尽管e中在我们看来分类完成,但是并没有达到所设定的阈值,所以还需要进一步计算簇心、更新点类别… 当达到阈值时就如f所看到的。

OK,经过上面简单的讲解,大家应该有所理解K-Means算法了,总结下来就是:计算样本到簇心的距离更新样本的簇、通过样本位置更新簇心、判断阈值、再次更新下一轮。

2、 Streaming K-Means手算

2.1 Spark Streaming中K-Means的介绍

Streaming k-means 中对实时计算K-Means有介绍:

c t + 1 = c t n t α + x t m t n t α + m t c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} ct+1=ntα+mtctntα+xtmt

n t + 1 = n t + m t n_{t+1} = n_t+m_t nt+1=nt+mt

  • ct是t代簇心,是一个向量。
  • nt是属于t代簇的样本数,是Int。
  • xt是新batch的簇心,是一个向量。
  • mt是新batch中被加入ct簇的样本数,是Int。
  • α是遗忘因子,范围a∈(0~1)

可能上面这个公式你会有很多疑问,为什么有a、Kmeans不是用欧式距离更新簇心的吗?下面我将手算给大家看。

2.2 手算Streaming K-Means

现在假设是第0代,k=2,所以需要随机初始化簇心,并且加入一个新的batch来迭代。我将用a=0a=1a=0.5 三个参数来帮助大家理解a遗忘因子,数据维度是2。

2.2.1 明确参数
  • c_t作为随机初始化的簇心,我为了方便定[[0,0],[10,10]]
  • batch作为新到达的一批样本点,我定[[0,1],[1,0],[2,1],[9,9]]
  • a分别取 a=0a=1a=0.5
2.2.2 初始图

在这里插入图片描述

蓝色是初始化簇心,红色为新batch的样本点。

2.2.3 迭代计算

在新batch进入的时候首先会计算离自身最近的簇心,然后分别定ntxtmt这些参数。我们以[0,0]这个簇为例,nt=0xt=[1,2/3]mt=3

2.2.3.1 第一次迭代a=0

代入
c t + 1 = c t n t α + x t m t n t α + m t c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} ct+1=ntα+mtctntα+xtmt
知道
c t + 1 = [ 0 , 0 ] ∗ 0 ∗ 0 + [ 1 , 2 / 3 ] ∗ 3 0 ∗ 0 + 3 = [ 1 , 2 / 3 ] c_{t+1} =\frac{[0,0]*0*0+[1,2/3]*3}{0*0+3} = [1,2/3] ct+1=00+3[0,0]00+[1,2/3]3=[1,2/3]

2.2.3.2 第一次迭代a=0.5

代入
c t + 1 = c t n t α + x t m t n t α + m t c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} ct+1=ntα+mtctntα+xtmt
知道
c t + 1 = [ 0 , 0 ] ∗ 0 ∗ 0.5 + [ 1 , 2 / 3 ] ∗ 3 0 ∗ 0.5 + 3 = [ 1 , 2 / 3 ] c_{t+1} = \frac{[0,0]*0*0.5+[1,2/3]*3}{0*0.5+3} = [1,2/3] ct+1=00.5+3[0,0]00.5+[1,2/3]3=[1,2/3]

2.2.3.3 第一次迭代a=1

c t + 1 = [ 0 , 0 ] ∗ 0 ∗ 1 + [ 1 , 2 / 3 ] ∗ 3 0 ∗ 1 + 3 = [ 1 , 2 / 3 ] c_{t+1} = \frac{[0,0]*0*1+[1,2/3]*3}{0*1+3} = [1,2/3] ct+1=01+3[0,0]01+[1,2/3]3=[1,2/3]

在这里插入图片描述

此时新簇心好像都一样,但是nt什么的已经改变,再来一个batch看看,新batch我定[[-1,0],[0,-1],[11,9]],跟上一次迭代一样,首先通过距离计算得知左下角的簇新加入的是两个黄色的样本点。

再次重申一下参数,记得算一下n_{t+1}

  • c_{t+1}=[1,2/3]
  • n_{t+1}=3
  • x_{t+1}=[-0.5,-0.5]
  • m_{t+1}=2
2.2.3.4 第二次迭代a=0

代入公式:
c t + 2 = c t + 1 n t + 1 α + x t + 1 m t + 1 n t + 1 α + m t + 1 = [ 1 , 2 / 3 ] ∗ 3 ∗ 0 + [ − 0.5 , − 0.5 ] ∗ 2 3 ∗ 0 + 2 = [ − 0.5 , − 0.5 ] c_{t+2} = \frac{c_{t+1}n_{t+1}\alpha + x_{t+1}m_{t+1}}{n_{t+1}\alpha+m_{t+1}} = \frac{[1,2/3]*3*0+[-0.5,-0.5]*2}{3*0+2} = [-0.5,-0.5] ct+2=nt+1α+mt+1ct+1nt+1α+xt+1mt+1=30+2[1,2/3]30+[0.5,0.5]2=[0.5,0.5]

2.2.3.5 第二次迭代a=0.5

代入公式:
c t + 2 = c t + 1 n t + 1 α + x t + 1 m t + 1 n t + 1 α + m t + 1 = [ 1 , 2 / 3 ] ∗ 3 ∗ 0.5 + [ − 0.5 , − 0.5 ] ∗ 2 3 ∗ 0.5 + 2 = [ 1 7 , 0 ] c_{t+2} = \frac{c_{t+1}n_{t+1}\alpha + x_{t+1}m_{t+1}}{n_{t+1}\alpha+m_{t+1}} = \frac{[1,2/3]*3*0.5+[-0.5,-0.5]*2}{3*0.5+2} = [\frac{1}{7},0] ct+2=nt+1α+mt+1ct+1nt+1α+xt+1mt+1=30.5+2[1,2/3]30.5+[0.5,0.5]2=[71,0]

2.2.3.6 第二次迭代a=1

代入公式:
c t + 2 = c t + 1 n t + 1 α + x t + 1 m t + 1 n t + 1 α + m t + 1 = [ 1 , 2 / 3 ] ∗ 3 ∗ 1 + [ − 0.5 , − 0.5 ] ∗ 2 3 ∗ 1 + 2 = [ 2 5 , 1 5 ] c_{t+2} = \frac{c_{t+1}n_{t+1}\alpha + x_{t+1}m_{t+1}}{n_{t+1}\alpha+m_{t+1}} = \frac{[1,2/3]*3*1+[-0.5,-0.5]*2}{3*1+2} = [\frac{2}{5},\frac{1}{5}] ct+2=nt+1α+mt+1ct+1nt+1α+xt+1mt+1=31+2[1,2/3]31+[0.5,0.5]2=[52,51]

OK,现在我们可以看出差距了。a作为遗忘因子在更新簇心的过程中起到至关重要的作用。当a=0时,簇是完全无记忆的,每次更新新簇心都由新batch的点来决定;当a=1时,簇是完全无法遗忘的,每次更新由新batch的点和以往所有点决定;当a∈(0,1)时,簇有部分遗忘的能力,且只对以往所有点遗忘,a越接近0记忆越差。

总结

综上,Spark官方文档关于Streaming K-Means的解释我们已经了解,它类似而又不同于单机离线Kmeans。Streaming K-Means更新由每一个batch的样本点决定,且以往的数据最终只表现在簇心,而单机离线Kmeans没有遗忘因子,每一次迭代虽然会改变簇心的位置,但是同样也会被所有样本左右。

无论

3、Streaming K-Means源码解读

接下来就到最难啃的阶段,我们尝试去看一看StreamingKMeans的源码,看他在代码里是如何操作的,一方面来印证之前的想法是否正确,另一方面也希望学习到其他新的东西。

在这里我们依旧使用Spark官方的示例:

import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setAppName("StreamingKMeansExample")
val ssc = new StreamingContext(conf, Seconds(args(2).toLong))

val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)

val model = new StreamingKMeans()
  .setK(args(3).toInt)
  .setDecayFactor(1.0)
  .setRandomCenters(args(4).toInt, 0.0)

model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()

由于我直到现在还是没法正常运行这个示例,所以没法用断点的方式查看内部的实际运行状态以及数据类型,所以我是通过纯用脑的方式来“跑”这份代码的。可能有些问题,大家有什么问题也可以在文章下留言。

第一第二行的confssc都不用说了,配置一些环境。

3.1 读取数据

接下来

val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)

显然这两行代码是读取数据,但是具体什么格式的数据,并且map后会变成什么样呢?

在scala示例旁边还有个python示例,如图:

在这里插入图片描述

点击可以看到:

trainingData = sc.textFile("data/mllib/kmeans_data.txt").map(lambda line:Vectors.dense([float(x) for x in line.strip().split(' ')]))

testingData = sc.textFile("data/mllib/streaming_kmeans_data_test.txt").map(parse)

显然这就是数据,通过github上spark里查找可以得到数据格式如下:

# kmeans_data.txt
0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2
# streaming_kmeans_data_test.txt
(1.0), [1.7, 0.4, 0.9]
(2.0), [2.2, 1.8, 0.0]

其实我在这里也花了不少时间看了源码,内部还是有一些差异,但是为了保持文章的内容连贯性我直接给大家说结论:

  • ssc逐行读取数据(形如kmeans_data.txt),然后通过map方法对其中数据转成Vector对象,而这个对象打印出来是[0.0,0.0,0.0]这样的。
  • 同理,ssc读取streaming_kmeans_data_test.txt数据,将其转为LabeledPoint对象,这个对象有两个变量,一个是label:Double,另一个是features:Vector

所以读取数据就是将特定的数据格式转为mllib中的对象,方便接下来的操作。

3.2 生成模型

val model = new StreamingKMeans()
  .setK(args(3).toInt)
  .setDecayFactor(1.0)
  .setRandomCenters(args(4).toInt, 0.0)

这段代码显然是设置一些相关的参数了,比如kadimweight。可以理解成手算过程中第0次迭代。

当然,为了接下来更好地理解,我给大家看一下StreamingKMeans对象。

3.2.1 StreamingKMeans 对象
@Since("1.2.0")
class StreamingKMeans @Since("1.2.0") (
    @Since("1.2.0") var k: Int,
    @Since("1.2.0") var decayFactor: Double,
    @Since("1.2.0") var timeUnit: String) extends Logging with Serializable {
    
	@Since("1.2.0")
  	def this() = this(2, 1.0, StreamingKMeans.BATCHES)

  	protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)

可以看到StreamingKMeans有三个参数kdecayFactortimeUnit ,并且内部还有一个StreamingKMeansModel对象,这显然不是一样的东西。

经过简单的查看,这个类里都是一些设置参数的方法,**最重要的就是trainOnpredictOn方法,而这些方法无一例外的都调用了StreamingKMeansModel对象的updatepredict方法。**可见,真正的实现都放在StreamingKMeansModel里。

3.2.2 StreamingKMeansModel 对象

直接点开trainOn方法就可以到达StreamingKMeans调用StreamingKMeansModel的地方:

model.trainOn(trainingData)
/**
   * Update the clustering model by training on batches of data from a DStream.
   * This operation registers a DStream for training the model,
   * checks whether the cluster centers have been initialized,
   * and updates the model using each batch of data from the stream.
   *
   * @param data DStream containing vector data
   */
  @Since("1.2.0")
  def trainOn(data: DStream[Vector]) {
    assertInitialized()
    data.foreachRDD { (rdd, time) =>
      model = model.update(rdd, decayFactor, timeUnit)
    }
  }

很熟悉的,这里data我们终于可以看到其类型:data: DStream[Vector],那么刚才读取数据部分说的没错了。经过断言检查,DStream使用foreachRDD操作,显然在这里我们可以知道数据在不同的分区做统一的操作。

而每一个RDD都使用的是StreamingKMeansModelupdate方法,点进去就可以到达具体实现了,这里我先不贴这个方法的代码,我们还是先看一下StreamingKMeansModel对象有什么内容

@Since("1.2.0")
class StreamingKMeansModel @Since("1.2.0") (
    @Since("1.2.0") override val clusterCenters: Array[Vector],
    @Since("1.2.0") val clusterWeights: Array[Double])
  extends KMeansModel(clusterCenters) with Logging {

很惊讶,只有两个变量,一个是clusterCenters: Array[Vector],一个是clusterWeights: Array[Double],经过观察名称可以知道其含义:

  • clusterCenters: Array[Vector]:这个变量用Array来装,里面全是Vector,而且名字叫clusterCenters那么显然这个变量就是装的簇心向量。所以说每一次迭代都会改变这个簇心Array的内容。
  • clusterWeights: Array[Double]:这个变量也是Array装,且里面是Double类型,名字叫簇权重,目前还是很模糊。

3.3 StreamingKMeansModelupdate方法

看完StreamingKMeansModel对象,我们直接看update,给大家贴源码(很长,可以跳过源码看我讲解):

/**
   * Perform a k-means update on a batch of data.
   */
  @Since("1.2.0")
  def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {

    // find nearest cluster to each point
    val closest = data.map(point => (this.predict(point), (point, 1L)))

    // get sums and counts for updating each cluster
    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
      BLAS.axpy(1.0, p2._1, p1._1)
      (p1._1, p1._2 + p2._2)
    }
    val dim = clusterCenters(0).size

    val pointStats: Array[(Int, (Vector, Long))] = closest
      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
      .collect()

    val discount = timeUnit match {
      case StreamingKMeans.BATCHES => decayFactor
      case StreamingKMeans.POINTS =>
        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
          n
        }.sum
        math.pow(decayFactor, numNewPoints)
    }

    // apply discount to weights
    BLAS.scal(discount, Vectors.dense(clusterWeights))

    // implement update rule
    pointStats.foreach { case (label, (sum, count)) =>
      val centroid = clusterCenters(label)

      val updatedWeight = clusterWeights(label) + count
      val lambda = count / math.max(updatedWeight, 1e-16)

      clusterWeights(label) = updatedWeight
      BLAS.scal(1.0 - lambda, centroid)
      BLAS.axpy(lambda / count, sum, centroid)

      // display the updated cluster centers
      val display = clusterCenters(label).size match {
        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
        case _ => centroid.toArray.mkString("[", ",", "]")
      }

      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
    }

    // Check whether the smallest cluster is dying. If so, split the largest cluster.
    val weightsWithIndex = clusterWeights.view.zipWithIndex
    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
    if (minWeight < 1e-8 * maxWeight) {
      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
      val weight = (maxWeight + minWeight) / 2.0
      clusterWeights(largest) = weight
      clusterWeights(smallest) = weight
      val largestClusterCenter = clusterCenters(largest)
      val smallestClusterCenter = clusterCenters(smallest)
      var j = 0
      while (j < dim) {
        val x = largestClusterCenter(j)
        val p = 1e-14 * math.max(math.abs(x), 1.0)
        largestClusterCenter.asBreeze(j) = x + p
        smallestClusterCenter.asBreeze(j) = x - p
        j += 1
      }
    }

    new StreamingKMeansModel(clusterCenters, clusterWeights)
  }

这段源码有点长,但是我们一步步来应该没问题,接下来我还是会分成几个部分来讲,大家在这里可以先休息一下准备好了再继续看。

3.3.1 计算部分
	// find nearest cluster to each point
    val closest = data.map(point => (this.predict(point), (point, 1L)))

    // get sums and counts for updating each cluster
    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
      BLAS.axpy(1.0, p2._1, p1._1)
      (p1._1, p1._2 + p2._2)
    }
    val dim = clusterCenters(0).size

    val pointStats: Array[(Int, (Vector, Long))] = closest
      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
      .collect()

    val discount = timeUnit match {
      case StreamingKMeans.BATCHES => decayFactor
      case StreamingKMeans.POINTS =>
        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
          n
        }.sum
        math.pow(decayFactor, numNewPoints)
    }

    // apply discount to weights
    BLAS.scal(discount, Vectors.dense(clusterWeights))

    // implement update rule
    pointStats.foreach { case (label, (sum, count)) =>
      val centroid = clusterCenters(label)

      val updatedWeight = clusterWeights(label) + count
      val lambda = count / math.max(updatedWeight, 1e-16)

      clusterWeights(label) = updatedWeight
      BLAS.scal(1.0 - lambda, centroid)
      BLAS.axpy(lambda / count, sum, centroid)

      // display the updated cluster centers
      val display = clusterCenters(label).size match {
        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
        case _ => centroid.toArray.mkString("[", ",", "]")
      }

      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
    }

以上这段代码是更新簇心和簇权重的部分,对应源码的82~125行。还是很长,我们一行行看:

81~82行
	// find nearest cluster to each point
    val closest = data.map(point => (this.predict(point), (point, 1L)))

注释说这行代码是计算每个点最近的簇,注意data:RDD[Vector],通过map方法对其中的样本点point进行计算,调用StreamingKMeansModelpredict方法。

  • predict方法返回的类型是Int,即簇的ID。

那么这行代码closest类型是:RDD[(Int,(Vector,Long))],即tuple里面装的是离样本最近的簇ID和另一个tuple,小tuple装的又是样本点向量和1。

84~88行
	// get sums and counts for updating each cluster
    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
      BLAS.axpy(1.0, p2._1, p1._1)
      (p1._1, p1._2 + p2._2)
    }

这里定义了一个方法,这个方法传入两个tuple最终返回一个tuple,而里面做的操作是

/**
   * y += a * x
   */
  def axpy(a: Double, x: Vector, y: Vector): Unit = {

将p1点的向量与p2向量相加,并且把两个tuple的第二个元素相加,一起返回。

89行
    val dim = clusterCenters(0).size

得到簇的维度,之前我们new对象的时候是设置了维度的:

.setRandomCenters(3, 0.0)  // dim=3 weight=0.0

所以在model初始化的时候会帮我们生成随机的簇心Array和簇权重Array。这里取出的就是dim

91~93行
val pointStats: Array[(Int, (Vector, Long))] = closest
      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
      .collect()

closest调用了aggregateByKey,这个方法不了解的可以看Spark算子之aggregateByKey详解。简单来说,就是通过mergeContribs函数把closest的元素拿出来一个个相加,并且加之前还有一个零向量先跟第一个元素相加。当然它是有groupByKey的功能,即它不是乱加的啊!是有bear来的!它根据key来区分,然后再加。

所以最终我们得到的还是一样的类型格式,但是数量已经大大减少为k个,比如:

pointStats = [(1,([0.8,0.4,0.9],5)) , (2,([1.2,1.0,8.5] , 3) ) … ]

groupByKey、reduceByKey、aggregateByKey区别

95~102行
val discount = timeUnit match {
      case StreamingKMeans.BATCHES => decayFactor
      case StreamingKMeans.POINTS =>
        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
          n
        }.sum
        math.pow(decayFactor, numNewPoints)
    }

通过timeUnit的类型来判断是batch还是point,然后计算出discount的值。这个discount的值最后要加入计算。至于它的用途我目前还是很模糊。

104~105行
	// apply discount to weights
    BLAS.scal(discount, Vectors.dense(clusterWeights))

Vectors.dense(clusterWeights)Array对象转为Vector对象,然后跟discount相乘。即计算簇的权重。

107~125行(写在代码里)
	// implement update rule
	// 遍历每一个新加入的样本的总状态
    pointStats.foreach { case (label, (sum, count)) =>
      // 从簇心Array中取出簇心向量
      val centroid = clusterCenters(label)
	  // 从簇权重Array中取出对应簇的权重 再加上该簇新样本点的数目
      val updatedWeight = clusterWeights(label) + count
      // λ值计算
      val lambda = count / math.max(updatedWeight, 1e-16)
	  // 把该簇对应权重赋值回去
      clusterWeights(label) = updatedWeight
      // 簇心向量 = (1.0-λ)*簇心向量
      BLAS.scal(1.0 - lambda, centroid)
      // 簇心向量 += (λ/count) * sum sum是之前该簇所有样本的向量之和
      BLAS.axpy(lambda / count, sum, centroid)

      // display the updated cluster centers
      // 这里没什么用
      val display = clusterCenters(label).size match {
        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
        case _ => centroid.toArray.mkString("[", ",", "]")
      }

      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
    }
3.3.1阶段总结

OK,到这里更新簇心、更新权重的部分代码就讲完了。总的来说就是:

  • 计算新样本点到各个簇心的距离来划分归于哪一个簇
  • 计算各个簇新来的样本的总状态(向量之和)
  • 通过遗忘因子decayFactor计算discount
  • 通过discount来更新权重
  • 通过discount来更新簇心坐标

虽然代码的操作跟Spark官网的公式很相似,但是好像有点对不上,先继续看吧。

3.3.2 划分部分
// Check whether the smallest cluster is dying. If so, split the largest cluster.
    val weightsWithIndex = clusterWeights.view.zipWithIndex
    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
    if (minWeight < 1e-8 * maxWeight) {
      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
      val weight = (maxWeight + minWeight) / 2.0
      clusterWeights(largest) = weight
      clusterWeights(smallest) = weight
      val largestClusterCenter = clusterCenters(largest)
      val smallestClusterCenter = clusterCenters(smallest)
      var j = 0
      while (j < dim) {
        val x = largestClusterCenter(j)
        val p = 1e-14 * math.max(math.abs(x), 1.0)
        largestClusterCenter.asBreeze(j) = x + p
        smallestClusterCenter.asBreeze(j) = x - p
        j += 1
      }
    }

    new StreamingKMeansModel(clusterCenters, clusterWeights)
  }

注释说这部分代码是用来决定是否要把将死的簇杀掉,用最大的簇一拆为二。

我们继续!

127~130行
// Check whether the smallest cluster is dying. If so, split the largest cluster.
    val weightsWithIndex = clusterWeights.view.zipWithIndex
    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)

clusterWeights:Array[Double],调用了zipWithIndex,这个方法把index和值缝合起来,形如:

[(5.47,1),(6.21,2)....]

然后通过第一个值(权重大小)取得最大权重、最小权重、最大权重下标、最小权重下标

131~146行
if (minWeight < 1e-8 * maxWeight) {
      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
      val weight = (maxWeight + minWeight) / 2.0
      clusterWeights(largest) = weight
      clusterWeights(smallest) = weight
      val largestClusterCenter = clusterCenters(largest)
      val smallestClusterCenter = clusterCenters(smallest)
      var j = 0
      while (j < dim) {
        val x = largestClusterCenter(j)
        val p = 1e-14 * math.max(math.abs(x), 1.0)
        largestClusterCenter.asBreeze(j) = x + p
        smallestClusterCenter.asBreeze(j) = x - p
        j += 1
      }
    }

判断,当最小权重小于10的-8次方倍的最大权重时,我们就需要划分了。

  • 首先把最大最小权重取均值。
  • 将最大最小权重都赋值给簇权重Array对应的位置。
  • 取出最大最小权重对应的簇心向量。
  • 对每一维进行操作:
    • 取出最大权重向量的第j维,得x。
    • 取该维绝对值与1比较,取大值与10的-14次方相乘,得p。
    • 最大权重向量第j维赋值x+p,最小权重第j维取x-p。

**遍历完成后将死的簇已经没了,取而代之的是最大权重簇一分为二的两个簇。**同时,我们也明白了权重Array的作用:

  • 可以影响簇心位置
  • 可以判断将死簇
148行
new StreamingKMeansModel(clusterCenters, clusterWeights)

经过上面的计算,簇心和簇权重Array都有了一些变化,最后通过new把模型给重新更新。

总结

经过查看Spark官方文档,我们手算出了簇心的迭代过程,也了解到遗忘因子a的作用。然后查看Streaming K-Means源码,我们也了解到模型训练阶段的具体操作。但是,我们同时也发现代码跟官方文档的公式好像有出入,而且这个算法是如何做到分布式计算的呢?这两个问题还需要进一步探究。

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值