Streaming是怎么做KMeans的?
大家好,我是一拳就能打爆你A柱的男人
大家在学机器学习的时候一定看过K-Means算法,但是各位有没有想过在实时计算的时候是如何做K-Means的呢?接下来我打算从下面几个方面来给大家梳理一下:1、K-Means算法原理,2、Streaming K-Means手算,3、Streaming K-Means源码解读。
关键词:StreamingKmeans,实时KMeans,源码解读StreamingKmeans
1、 K-Means算法原理
关于K-Means算法我之前有一篇博客也讲过,并且附带案例。各位有兴趣的可以去看一看:K-Means算法及相关案例 。接下来我还是简单讲一下吧。
首先请看下面这张图:
图1来源:KMeans 算法(一)
- a:显然有两个簇的样本,但是此时并没有做出聚类。
- b:随机生成了红蓝两个点(即簇心)
- c:此时开始迭代(c~f步)每个样本都会分别计算样本自身到两个簇心的距离,并将自身归类为最近簇心一类。
- d:在c中样本分成红蓝两色,但是显然没有分类完全,所以需要更新簇心位置,簇心位置由该类中所有样本的位置(向量)决定。
- e:经过d的簇心位置计算,样本已经很靠近自身簇的样本点,此时再次更新簇心。
- f:尽管e中在我们看来分类完成,但是并没有达到所设定的阈值,所以还需要进一步计算簇心、更新点类别… 当达到阈值时就如f所看到的。
OK,经过上面简单的讲解,大家应该有所理解K-Means算法了,总结下来就是:计算样本到簇心的距离更新样本的簇、通过样本位置更新簇心、判断阈值、再次更新下一轮。
2、 Streaming K-Means手算
2.1 Spark Streaming中K-Means的介绍
在Streaming k-means 中对实时计算K-Means有介绍:
c t + 1 = c t n t α + x t m t n t α + m t c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} ct+1=ntα+mtctntα+xtmt
n t + 1 = n t + m t n_{t+1} = n_t+m_t nt+1=nt+mt
ct
是t代簇心,是一个向量。nt
是属于t代簇的样本数,是Int。xt
是新batch的簇心,是一个向量。mt
是新batch中被加入ct
簇的样本数,是Int。α
是遗忘因子,范围a∈(0~1)
。
可能上面这个公式你会有很多疑问,为什么有a
、Kmeans不是用欧式距离更新簇心的吗?下面我将手算给大家看。
2.2 手算Streaming K-Means
现在假设是第0代,k=2,所以需要随机初始化簇心,并且加入一个新的batch来迭代。我将用a=0
、a=1
、a=0.5
三个参数来帮助大家理解a
遗忘因子,数据维度是2。
2.2.1 明确参数
c_t
作为随机初始化的簇心,我为了方便定[[0,0],[10,10]]
。batch
作为新到达的一批样本点,我定[[0,1],[1,0],[2,1],[9,9]]
。a
分别取a=0
、a=1
、a=0.5
。
2.2.2 初始图
蓝色是初始化簇心,红色为新batch
的样本点。
2.2.3 迭代计算
在新batch
进入的时候首先会计算离自身最近的簇心,然后分别定nt
、xt
、mt
这些参数。我们以[0,0]
这个簇为例,nt=0
、xt=[1,2/3]
、mt=3
。
2.2.3.1 第一次迭代a=0
代入
c
t
+
1
=
c
t
n
t
α
+
x
t
m
t
n
t
α
+
m
t
c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
ct+1=ntα+mtctntα+xtmt
知道
c
t
+
1
=
[
0
,
0
]
∗
0
∗
0
+
[
1
,
2
/
3
]
∗
3
0
∗
0
+
3
=
[
1
,
2
/
3
]
c_{t+1} =\frac{[0,0]*0*0+[1,2/3]*3}{0*0+3} = [1,2/3]
ct+1=0∗0+3[0,0]∗0∗0+[1,2/3]∗3=[1,2/3]
2.2.3.2 第一次迭代a=0.5
代入
c
t
+
1
=
c
t
n
t
α
+
x
t
m
t
n
t
α
+
m
t
c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
ct+1=ntα+mtctntα+xtmt
知道
c
t
+
1
=
[
0
,
0
]
∗
0
∗
0.5
+
[
1
,
2
/
3
]
∗
3
0
∗
0.5
+
3
=
[
1
,
2
/
3
]
c_{t+1} = \frac{[0,0]*0*0.5+[1,2/3]*3}{0*0.5+3} = [1,2/3]
ct+1=0∗0.5+3[0,0]∗0∗0.5+[1,2/3]∗3=[1,2/3]
2.2.3.3 第一次迭代a=1
c t + 1 = [ 0 , 0 ] ∗ 0 ∗ 1 + [ 1 , 2 / 3 ] ∗ 3 0 ∗ 1 + 3 = [ 1 , 2 / 3 ] c_{t+1} = \frac{[0,0]*0*1+[1,2/3]*3}{0*1+3} = [1,2/3] ct+1=0∗1+3[0,0]∗0∗1+[1,2/3]∗3=[1,2/3]
此时新簇心好像都一样,但是nt
什么的已经改变,再来一个batch
看看,新batch
我定[[-1,0],[0,-1],[11,9]]
,跟上一次迭代一样,首先通过距离计算得知左下角的簇新加入的是两个黄色的样本点。
再次重申一下参数,记得算一下n_{t+1}
:
c_{t+1}=[1,2/3]
n_{t+1}=3
x_{t+1}=[-0.5,-0.5]
m_{t+1}=2
2.2.3.4 第二次迭代a=0
代入公式:
c
t
+
2
=
c
t
+
1
n
t
+
1
α
+
x
t
+
1
m
t
+
1
n
t
+
1
α
+
m
t
+
1
=
[
1
,
2
/
3
]
∗
3
∗
0
+
[
−
0.5
,
−
0.5
]
∗
2
3
∗
0
+
2
=
[
−
0.5
,
−
0.5
]
c_{t+2} = \frac{c_{t+1}n_{t+1}\alpha + x_{t+1}m_{t+1}}{n_{t+1}\alpha+m_{t+1}} = \frac{[1,2/3]*3*0+[-0.5,-0.5]*2}{3*0+2} = [-0.5,-0.5]
ct+2=nt+1α+mt+1ct+1nt+1α+xt+1mt+1=3∗0+2[1,2/3]∗3∗0+[−0.5,−0.5]∗2=[−0.5,−0.5]
2.2.3.5 第二次迭代a=0.5
代入公式:
c
t
+
2
=
c
t
+
1
n
t
+
1
α
+
x
t
+
1
m
t
+
1
n
t
+
1
α
+
m
t
+
1
=
[
1
,
2
/
3
]
∗
3
∗
0.5
+
[
−
0.5
,
−
0.5
]
∗
2
3
∗
0.5
+
2
=
[
1
7
,
0
]
c_{t+2} = \frac{c_{t+1}n_{t+1}\alpha + x_{t+1}m_{t+1}}{n_{t+1}\alpha+m_{t+1}} = \frac{[1,2/3]*3*0.5+[-0.5,-0.5]*2}{3*0.5+2} = [\frac{1}{7},0]
ct+2=nt+1α+mt+1ct+1nt+1α+xt+1mt+1=3∗0.5+2[1,2/3]∗3∗0.5+[−0.5,−0.5]∗2=[71,0]
2.2.3.6 第二次迭代a=1
代入公式:
c
t
+
2
=
c
t
+
1
n
t
+
1
α
+
x
t
+
1
m
t
+
1
n
t
+
1
α
+
m
t
+
1
=
[
1
,
2
/
3
]
∗
3
∗
1
+
[
−
0.5
,
−
0.5
]
∗
2
3
∗
1
+
2
=
[
2
5
,
1
5
]
c_{t+2} = \frac{c_{t+1}n_{t+1}\alpha + x_{t+1}m_{t+1}}{n_{t+1}\alpha+m_{t+1}} = \frac{[1,2/3]*3*1+[-0.5,-0.5]*2}{3*1+2} = [\frac{2}{5},\frac{1}{5}]
ct+2=nt+1α+mt+1ct+1nt+1α+xt+1mt+1=3∗1+2[1,2/3]∗3∗1+[−0.5,−0.5]∗2=[52,51]
OK,现在我们可以看出差距了。a
作为遗忘因子在更新簇心的过程中起到至关重要的作用。当a=0
时,簇是完全无记忆的,每次更新新簇心都由新batch
的点来决定;当a=1
时,簇是完全无法遗忘的,每次更新由新batch
的点和以往所有点决定;当a∈(0,1)
时,簇有部分遗忘的能力,且只对以往所有点遗忘,a
越接近0记忆越差。
总结
综上,Spark官方文档关于Streaming K-Means的解释我们已经了解,它类似而又不同于单机离线Kmeans。Streaming K-Means更新由每一个batch
的样本点决定,且以往的数据最终只表现在簇心,而单机离线Kmeans没有遗忘因子,每一次迭代虽然会改变簇心的位置,但是同样也会被所有样本左右。
无论
3、Streaming K-Means源码解读
接下来就到最难啃的阶段,我们尝试去看一看StreamingKMeans
的源码,看他在代码里是如何操作的,一方面来印证之前的想法是否正确,另一方面也希望学习到其他新的东西。
在这里我们依旧使用Spark官方的示例:
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.{Seconds, StreamingContext}
val conf = new SparkConf().setAppName("StreamingKMeansExample")
val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
val model = new StreamingKMeans()
.setK(args(3).toInt)
.setDecayFactor(1.0)
.setRandomCenters(args(4).toInt, 0.0)
model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
ssc.start()
ssc.awaitTermination()
由于我直到现在还是没法正常运行这个示例,所以没法用断点的方式查看内部的实际运行状态以及数据类型,所以我是通过纯用脑的方式来“跑”这份代码的。可能有些问题,大家有什么问题也可以在文章下留言。
第一第二行的conf
和ssc
都不用说了,配置一些环境。
3.1 读取数据
接下来
val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
显然这两行代码是读取数据,但是具体什么格式的数据,并且map后会变成什么样呢?
在scala示例旁边还有个python示例,如图:
点击可以看到:
trainingData = sc.textFile("data/mllib/kmeans_data.txt").map(lambda line:Vectors.dense([float(x) for x in line.strip().split(' ')]))
testingData = sc.textFile("data/mllib/streaming_kmeans_data_test.txt").map(parse)
显然这就是数据,通过github上spark里查找可以得到数据格式如下:
# kmeans_data.txt
0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2
# streaming_kmeans_data_test.txt
(1.0), [1.7, 0.4, 0.9]
(2.0), [2.2, 1.8, 0.0]
其实我在这里也花了不少时间看了源码,内部还是有一些差异,但是为了保持文章的内容连贯性我直接给大家说结论:
ssc
逐行读取数据(形如kmeans_data.txt),然后通过map
方法对其中数据转成Vector对象,而这个对象打印出来是[0.0,0.0,0.0]
这样的。- 同理,
ssc
读取streaming_kmeans_data_test.txt数据,将其转为LabeledPoint
对象,这个对象有两个变量,一个是label:Double
,另一个是features:Vector
。
所以读取数据就是将特定的数据格式转为mllib
中的对象,方便接下来的操作。
3.2 生成模型
val model = new StreamingKMeans()
.setK(args(3).toInt)
.setDecayFactor(1.0)
.setRandomCenters(args(4).toInt, 0.0)
这段代码显然是设置一些相关的参数了,比如k
、a
、dim
、weight
。可以理解成手算过程中第0次迭代。
当然,为了接下来更好地理解,我给大家看一下StreamingKMeans
对象。
3.2.1 StreamingKMeans 对象
@Since("1.2.0")
class StreamingKMeans @Since("1.2.0") (
@Since("1.2.0") var k: Int,
@Since("1.2.0") var decayFactor: Double,
@Since("1.2.0") var timeUnit: String) extends Logging with Serializable {
@Since("1.2.0")
def this() = this(2, 1.0, StreamingKMeans.BATCHES)
protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
可以看到StreamingKMeans
有三个参数k
、decayFactor
、timeUnit
,并且内部还有一个StreamingKMeansModel
对象,这显然不是一样的东西。
经过简单的查看,这个类里都是一些设置参数的方法,**最重要的就是trainOn
和predictOn
方法,而这些方法无一例外的都调用了StreamingKMeansModel
对象的update
和predict
方法。**可见,真正的实现都放在StreamingKMeansModel
里。
3.2.2 StreamingKMeansModel 对象
直接点开trainOn方法就可以到达StreamingKMeans
调用StreamingKMeansModel
的地方:
model.trainOn(trainingData)
/**
* Update the clustering model by training on batches of data from a DStream.
* This operation registers a DStream for training the model,
* checks whether the cluster centers have been initialized,
* and updates the model using each batch of data from the stream.
*
* @param data DStream containing vector data
*/
@Since("1.2.0")
def trainOn(data: DStream[Vector]) {
assertInitialized()
data.foreachRDD { (rdd, time) =>
model = model.update(rdd, decayFactor, timeUnit)
}
}
很熟悉的,这里data
我们终于可以看到其类型:data: DStream[Vector]
,那么刚才读取数据部分说的没错了。经过断言检查,DStream
使用foreachRDD
操作,显然在这里我们可以知道数据在不同的分区做统一的操作。
而每一个RDD
都使用的是StreamingKMeansModel
的update
方法,点进去就可以到达具体实现了,这里我先不贴这个方法的代码,我们还是先看一下StreamingKMeansModel
对象有什么内容:
@Since("1.2.0")
class StreamingKMeansModel @Since("1.2.0") (
@Since("1.2.0") override val clusterCenters: Array[Vector],
@Since("1.2.0") val clusterWeights: Array[Double])
extends KMeansModel(clusterCenters) with Logging {
很惊讶,只有两个变量,一个是clusterCenters: Array[Vector]
,一个是clusterWeights: Array[Double]
,经过观察名称可以知道其含义:
clusterCenters: Array[Vector]
:这个变量用Array
来装,里面全是Vector
,而且名字叫clusterCenters
那么显然这个变量就是装的簇心向量。所以说每一次迭代都会改变这个簇心Array
的内容。clusterWeights: Array[Double]
:这个变量也是Array
装,且里面是Double
类型,名字叫簇权重,目前还是很模糊。
3.3 StreamingKMeansModel
的update
方法
看完StreamingKMeansModel
对象,我们直接看update
,给大家贴源码(很长,可以跳过源码看我讲解):
/**
* Perform a k-means update on a batch of data.
*/
@Since("1.2.0")
def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
// find nearest cluster to each point
val closest = data.map(point => (this.predict(point), (point, 1L)))
// get sums and counts for updating each cluster
val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
BLAS.axpy(1.0, p2._1, p1._1)
(p1._1, p1._2 + p2._2)
}
val dim = clusterCenters(0).size
val pointStats: Array[(Int, (Vector, Long))] = closest
.aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
.collect()
val discount = timeUnit match {
case StreamingKMeans.BATCHES => decayFactor
case StreamingKMeans.POINTS =>
val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
n
}.sum
math.pow(decayFactor, numNewPoints)
}
// apply discount to weights
BLAS.scal(discount, Vectors.dense(clusterWeights))
// implement update rule
pointStats.foreach { case (label, (sum, count)) =>
val centroid = clusterCenters(label)
val updatedWeight = clusterWeights(label) + count
val lambda = count / math.max(updatedWeight, 1e-16)
clusterWeights(label) = updatedWeight
BLAS.scal(1.0 - lambda, centroid)
BLAS.axpy(lambda / count, sum, centroid)
// display the updated cluster centers
val display = clusterCenters(label).size match {
case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
case _ => centroid.toArray.mkString("[", ",", "]")
}
logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
}
// Check whether the smallest cluster is dying. If so, split the largest cluster.
val weightsWithIndex = clusterWeights.view.zipWithIndex
val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
if (minWeight < 1e-8 * maxWeight) {
logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
val weight = (maxWeight + minWeight) / 2.0
clusterWeights(largest) = weight
clusterWeights(smallest) = weight
val largestClusterCenter = clusterCenters(largest)
val smallestClusterCenter = clusterCenters(smallest)
var j = 0
while (j < dim) {
val x = largestClusterCenter(j)
val p = 1e-14 * math.max(math.abs(x), 1.0)
largestClusterCenter.asBreeze(j) = x + p
smallestClusterCenter.asBreeze(j) = x - p
j += 1
}
}
new StreamingKMeansModel(clusterCenters, clusterWeights)
}
这段源码有点长,但是我们一步步来应该没问题,接下来我还是会分成几个部分来讲,大家在这里可以先休息一下准备好了再继续看。
3.3.1 计算部分
// find nearest cluster to each point
val closest = data.map(point => (this.predict(point), (point, 1L)))
// get sums and counts for updating each cluster
val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
BLAS.axpy(1.0, p2._1, p1._1)
(p1._1, p1._2 + p2._2)
}
val dim = clusterCenters(0).size
val pointStats: Array[(Int, (Vector, Long))] = closest
.aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
.collect()
val discount = timeUnit match {
case StreamingKMeans.BATCHES => decayFactor
case StreamingKMeans.POINTS =>
val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
n
}.sum
math.pow(decayFactor, numNewPoints)
}
// apply discount to weights
BLAS.scal(discount, Vectors.dense(clusterWeights))
// implement update rule
pointStats.foreach { case (label, (sum, count)) =>
val centroid = clusterCenters(label)
val updatedWeight = clusterWeights(label) + count
val lambda = count / math.max(updatedWeight, 1e-16)
clusterWeights(label) = updatedWeight
BLAS.scal(1.0 - lambda, centroid)
BLAS.axpy(lambda / count, sum, centroid)
// display the updated cluster centers
val display = clusterCenters(label).size match {
case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
case _ => centroid.toArray.mkString("[", ",", "]")
}
logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
}
以上这段代码是更新簇心和簇权重的部分,对应源码的82~125行。还是很长,我们一行行看:
81~82行
// find nearest cluster to each point
val closest = data.map(point => (this.predict(point), (point, 1L)))
注释说这行代码是计算每个点最近的簇,注意data:RDD[Vector]
,通过map
方法对其中的样本点point
进行计算,调用StreamingKMeansModel
的predict
方法。
predict
方法返回的类型是Int
,即簇的ID。
那么这行代码closest
类型是:RDD[(Int,(Vector,Long))]
,即tuple
里面装的是离样本最近的簇ID和另一个tuple
,小tuple
装的又是样本点向量和1。
84~88行
// get sums and counts for updating each cluster
val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
BLAS.axpy(1.0, p2._1, p1._1)
(p1._1, p1._2 + p2._2)
}
这里定义了一个方法,这个方法传入两个tuple
最终返回一个tuple
,而里面做的操作是
/**
* y += a * x
*/
def axpy(a: Double, x: Vector, y: Vector): Unit = {
将p1点的向量与p2向量相加,并且把两个tuple
的第二个元素相加,一起返回。
89行
val dim = clusterCenters(0).size
得到簇的维度,之前我们new对象的时候是设置了维度的:
.setRandomCenters(3, 0.0) // dim=3 weight=0.0
所以在model
初始化的时候会帮我们生成随机的簇心Array
和簇权重Array
。这里取出的就是dim
91~93行
val pointStats: Array[(Int, (Vector, Long))] = closest
.aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
.collect()
closest
调用了aggregateByKey
,这个方法不了解的可以看Spark算子之aggregateByKey详解。简单来说,就是通过mergeContribs
函数把closest
的元素拿出来一个个相加,并且加之前还有一个零向量先跟第一个元素相加。当然它是有groupByKey
的功能,即它不是乱加的啊!是有bear来的!它根据key来区分,然后再加。
所以最终我们得到的还是一样的类型格式,但是数量已经大大减少为k
个,比如:
pointStats
= [(1,([0.8,0.4,0.9],5)) , (2,([1.2,1.0,8.5] , 3) ) … ]
groupByKey、reduceByKey、aggregateByKey区别
95~102行
val discount = timeUnit match {
case StreamingKMeans.BATCHES => decayFactor
case StreamingKMeans.POINTS =>
val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
n
}.sum
math.pow(decayFactor, numNewPoints)
}
通过timeUnit
的类型来判断是batch
还是point
,然后计算出discount
的值。这个discount
的值最后要加入计算。至于它的用途我目前还是很模糊。
104~105行
// apply discount to weights
BLAS.scal(discount, Vectors.dense(clusterWeights))
Vectors.dense(clusterWeights)
把Array
对象转为Vector
对象,然后跟discount
相乘。即计算簇的权重。
107~125行(写在代码里)
// implement update rule
// 遍历每一个新加入的样本的总状态
pointStats.foreach { case (label, (sum, count)) =>
// 从簇心Array中取出簇心向量
val centroid = clusterCenters(label)
// 从簇权重Array中取出对应簇的权重 再加上该簇新样本点的数目
val updatedWeight = clusterWeights(label) + count
// λ值计算
val lambda = count / math.max(updatedWeight, 1e-16)
// 把该簇对应权重赋值回去
clusterWeights(label) = updatedWeight
// 簇心向量 = (1.0-λ)*簇心向量
BLAS.scal(1.0 - lambda, centroid)
// 簇心向量 += (λ/count) * sum sum是之前该簇所有样本的向量之和
BLAS.axpy(lambda / count, sum, centroid)
// display the updated cluster centers
// 这里没什么用
val display = clusterCenters(label).size match {
case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
case _ => centroid.toArray.mkString("[", ",", "]")
}
logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
}
3.3.1阶段总结
OK,到这里更新簇心、更新权重的部分代码就讲完了。总的来说就是:
- 计算新样本点到各个簇心的距离来划分归于哪一个簇
- 计算各个簇新来的样本的总状态(向量之和)
- 通过遗忘因子decayFactor计算discount
- 通过discount来更新权重
- 通过discount来更新簇心坐标
虽然代码的操作跟Spark官网的公式很相似,但是好像有点对不上,先继续看吧。
3.3.2 划分部分
// Check whether the smallest cluster is dying. If so, split the largest cluster.
val weightsWithIndex = clusterWeights.view.zipWithIndex
val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
if (minWeight < 1e-8 * maxWeight) {
logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
val weight = (maxWeight + minWeight) / 2.0
clusterWeights(largest) = weight
clusterWeights(smallest) = weight
val largestClusterCenter = clusterCenters(largest)
val smallestClusterCenter = clusterCenters(smallest)
var j = 0
while (j < dim) {
val x = largestClusterCenter(j)
val p = 1e-14 * math.max(math.abs(x), 1.0)
largestClusterCenter.asBreeze(j) = x + p
smallestClusterCenter.asBreeze(j) = x - p
j += 1
}
}
new StreamingKMeansModel(clusterCenters, clusterWeights)
}
注释说这部分代码是用来决定是否要把将死的簇杀掉,用最大的簇一拆为二。
我们继续!
127~130行
// Check whether the smallest cluster is dying. If so, split the largest cluster.
val weightsWithIndex = clusterWeights.view.zipWithIndex
val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
clusterWeights:Array[Double]
,调用了zipWithIndex
,这个方法把index
和值缝合起来,形如:
[(5.47,1),(6.21,2)....]
然后通过第一个值(权重大小)取得最大权重、最小权重、最大权重下标、最小权重下标。
131~146行
if (minWeight < 1e-8 * maxWeight) {
logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
val weight = (maxWeight + minWeight) / 2.0
clusterWeights(largest) = weight
clusterWeights(smallest) = weight
val largestClusterCenter = clusterCenters(largest)
val smallestClusterCenter = clusterCenters(smallest)
var j = 0
while (j < dim) {
val x = largestClusterCenter(j)
val p = 1e-14 * math.max(math.abs(x), 1.0)
largestClusterCenter.asBreeze(j) = x + p
smallestClusterCenter.asBreeze(j) = x - p
j += 1
}
}
判断,当最小权重小于10的-8次方倍的最大权重时,我们就需要划分了。
- 首先把最大最小权重取均值。
- 将最大最小权重都赋值给簇权重Array对应的位置。
- 取出最大最小权重对应的簇心向量。
- 对每一维进行操作:
- 取出最大权重向量的第
j
维,得x。 - 取该维绝对值与1比较,取大值与10的-14次方相乘,得p。
- 最大权重向量第
j
维赋值x+p,最小权重第j
维取x-p。
- 取出最大权重向量的第
**遍历完成后将死的簇已经没了,取而代之的是最大权重簇一分为二的两个簇。**同时,我们也明白了权重Array的作用:
- 可以影响簇心位置
- 可以判断将死簇
148行
new StreamingKMeansModel(clusterCenters, clusterWeights)
经过上面的计算,簇心和簇权重Array都有了一些变化,最后通过new把模型给重新更新。
总结
经过查看Spark官方文档,我们手算出了簇心的迭代过程,也了解到遗忘因子a
的作用。然后查看Streaming K-Means
源码,我们也了解到模型训练阶段的具体操作。但是,我们同时也发现代码跟官方文档的公式好像有出入,而且这个算法是如何做到分布式计算的呢?这两个问题还需要进一步探究。