A Detailed Analysis of the Spark MLlib StreamingKMeans Source Code
This analysis draws on http://blog.csdn.net/stevekangpei/article/details/76549267 and corrects several details in that article.
/**
* StreamingKMeansModel extends MLlib's KMeansModel for streaming
* algorithms, so it can keep track of a continuously updated weight
* associated with each cluster, and also update the model by
* doing a single iteration of the standard k-means algorithm.
*
* The update algorithm uses the "mini-batch" KMeans rule,
* generalized to incorporate forgetfulness (i.e. decay).
* The update rule (for each cluster) is:
*
* <blockquote>
* $$
* \begin{align}
* c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\
* n_t+1 &= n_t * a + m_t
* \end{align}
* $$
* </blockquote>
*
* Where c_t is the previously estimated centroid for that cluster,
* n_t is the number of points assigned to it thus far, x_t is the centroid
* estimated on the current batch, and m_t is the number of points assigned
* to that centroid in the current batch.
*
* The decay factor 'a' scales the contribution of the clusters as estimated thus far,
* by applying a as a discount weighting on the current point when evaluating
* new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
* are determined entirely by recent data. Lower values correspond to
* more forgetting.
*
* Decay can optionally be specified by a half life and associated
* time unit. The time unit can either be a batch of data or a single
* data point. Considering data arrived at time t, the half life h is defined
* such that at time t + h the discount applied to the data from t is 0.5.
* The definition remains the same whether the time unit is given
* as batches or points.
*/
This is a comment block from the source; a translation follows.
StreamingKMeansModel extends MLlib's KMeansModel for use in streaming settings. It can therefore continuously track the weight associated with each cluster, and it updates the clustering model by performing a single iteration of the standard k-means algorithm on each batch.
The update algorithm uses the "mini-batch" KMeans rule, generalized to include a forgetting factor (decay).
The update rule (for each cluster) is:
{{{
//Note: this differs from the comment in the source above. After reading the source carefully,
//the denominator should be as written here; the formula is a weighted average that yields the new center.
c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
n_t+1 = n_t * a + m_t
}}}
Here c_t is the cluster center from the previous update, n_t is the number of points assigned to that center so far, x_t is the centroid estimated from the current batch, and m_t is the number of points assigned to that centroid in the current batch.
The decay factor a controls how much earlier data is forgotten. If a = 1, all batches are weighted equally and earlier data is never forgotten. If a = 0, the new cluster centers are determined entirely by the most recent data. The closer a is to 0, the faster old data is forgotten.
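A quick worked example of the rule, with numbers chosen purely for illustration: take a one-dimensional cluster with c_t = 4.0 and n_t = 10, a batch centroid x_t = 8.0 over m_t = 10 points, and a = 0.5.
{{{
c_t+1 = (4.0 * 10 * 0.5 + 8.0 * 10) / (10 * 0.5 + 10) = (20 + 80) / 15 ≈ 6.67
n_t+1 = 10 * 0.5 + 10 = 15
}}}
With a = 1 the same batch would give (40 + 80) / 20 = 6.0, so the smaller a pulls the center more strongly toward the new data.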
Decay can alternatively be specified by a half-life together with a time unit: either a batch of data or a single data point.
For data arriving at time t, a half-life of h means that at time t + h the discount applied to the data from time t is 0.5.
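The relationship between half-life and decay factor follows directly from this definition: the discount applied after h time units must equal 0.5, so a^h = 0.5. This is essentially the conversion StreamingKMeans.setHalfLife performs:
{{{
// a^h = 0.5  =>  a = exp(ln(0.5) / h)
val decayFactor = math.exp(math.log(0.5) / halfLife)
}}}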
The core of streaming k-means is the implementation of the update formula above, which lives in the update method. The source file defines two classes: StreamingKMeansModel and StreamingKMeans.
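Before going into update, a minimal usage sketch may help show how the two classes fit together (the input path, batch interval, and parameter values below are illustrative only): StreamingKMeans drives the training loop over a DStream, and each batch produces an updated StreamingKMeansModel via update.
{{{
import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setAppName("StreamingKMeansExample")
val ssc = new StreamingContext(conf, Seconds(5))

// each input line is a vector literal such as "[1.0,2.0]"
val trainingData = ssc.textFileStream("/path/to/training").map(Vectors.parse)

val model = new StreamingKMeans()
  .setK(3)                   // number of clusters
  .setDecayFactor(0.5)       // the 'a' in the update rule
  .setRandomCenters(2, 0.0)  // dim = 2, initial weight = 0.0

model.trainOn(trainingData)  // runs update(...) once per batch

ssc.start()
ssc.awaitTermination()
}}}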
@Since("1.2.0")
//StreamingKMeansModel represents the current clustering state
class StreamingKMeansModel @Since("1.2.0") (
@Since("1.2.0") override val clusterCenters: Array[Vector],
@Since("1.2.0") val clusterWeights: Array[Double])
//The first parameter, clusterCenters, holds the cluster centers; each element is a Vector.
//The second parameter, clusterWeights, holds the current number of points around each center (i.e. its weight).
extends KMeansModel(clusterCenters) with Logging {
/**
* Perform a k-means update on a batch of data.
*/
@Since("1.2.0")
//update recomputes the cluster centers whenever a new batch of data arrives; it returns an updated StreamingKMeansModel object
def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
// Assign each point to its nearest cluster center; each element of closest is a nested pair (clusterIndex: Int, (point: Vector, 1L))
val closest = data.map(point => (this.predict(point), (point, 1L)))
//mergeContribs merges two (Vector, Long) contributions belonging to the same cluster;
//it returns a tuple of the same type.
val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
//In-place vector addition (BLAS.axpy mutates its last argument):
//p1._1 = p1._1 + 1.0 * p2._1
BLAS.axpy(1.0, p2._1, p1._1)
(p1._1, p1._2 + p2._2)
}
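//For example, merging (Vectors.dense(1.0, 2.0), 1L) with (Vectors.dense(3.0, 4.0), 1L)
//mutates the first vector to [4.0, 6.0] and returns ([4.0, 6.0], 2L):
//the running vector sum and point count of one cluster's batch contribution.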
//dim is the dimensionality of the cluster centers
val dim = clusterCenters(0).size
//aggregateByKey applies the mergeContribs function defined above, both within and across partitions.
//collect() then returns an array whose elements have type (Int, (Vector, Long)):
//the cluster index, the vector sum of all batch points assigned to that center, and the number of such points.
val pointStats: Array[(Int, (Vector, Long))] = closest
.aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
.collect()
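//For k = 3 and two-dimensional data, pointStats might look like
//Array((0, ([4.0, 6.0], 2L)), (2, ([1.5, 0.5], 1L))); cluster 1 received no
//points in this batch, so below it is only discounted, never re-centered.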
//discount measures how much the old model is forgotten. If the time unit is BATCHES, discount = decayFactor (the configured forgetting factor);
//if the time unit is POINTS, discount = decayFactor raised to the power of the number of points in the new batch.
val discount = timeUnit match {
case StreamingKMeans.BATCHES => decayFactor
case StreamingKMeans.POINTS =>
val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
n
}.sum
math.pow(decayFactor, numNewPoints)
}
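//For example, with decayFactor = 0.9, BATCHES gives discount = 0.9, while
//POINTS with 100 new points gives 0.9^100 ≈ 2.7e-5, so in POINTS mode a
//large batch wipes out almost all of the old model's weight.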
// Apply discount to the vector formed by clusterWeights. Vectors.dense wraps the existing array without copying, so BLAS.scal rescales clusterWeights in place.
BLAS.scal(discount, Vectors.dense(clusterWeights))
//Recall the update rule:
//c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
//n_t+1 = n_t * a + m_t
//The scal call above computes the n_t * a term.
// implement update rule
pointStats.foreach { case (label, (sum, count)) =>
val centroid = clusterCenters(label)
//updatedWeight = n_t * a + m_t
val updatedWeight = clusterWeights(label) + count
// lambda = m_t / (n_t * a + m_t)
val lambda = count / math.max(updatedWeight, 1e-16)
//update clusterWeights to n_t * a + m_t
clusterWeights(label) = updatedWeight
//1.0 - lambda = (n_t * a) / (n_t * a + m_t)
//c_t' = (c_t * n_t * a) / (n_t * a + m_t)
BLAS.scal(1.0 - lambda, centroid)
//sum = x_t * m_t (the vector sum of this batch's points assigned to this center)
// lambda / count = 1 / (n_t * a + m_t)
//c_t'' = [(x_t * m_t) / (n_t * a + m_t)] + c_t'
//i.e. c_t'' = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
BLAS.axpy(lambda / count, sum, centroid)
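//Worked example: c_t = [0.0], discounted weight n_t * a = 9.0, batch
//sum = [10.0], count m_t = 5. Then updatedWeight = 14.0 and lambda = 5/14;
//scal leaves [0.0] unchanged, and axpy adds (lambda / 5) * [10.0] = [10/14],
//matching ([0.0] * 9 + [10.0]) / 14 ≈ [0.714].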
// Log the updated cluster center; if its dimensionality exceeds 100, show only the first 100 components
val display = clusterCenters(label).size match {
case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
case _ => centroid.toArray.mkString("[", ",", "]")
}
logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
}
//Next, check whether any cluster is dying out; if one is, split the largest cluster in two.
//The next three lines zip each weight with its cluster index, then locate the heaviest and lightest clusters.
val weightsWithIndex = clusterWeights.view.zipWithIndex
val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
//If the lightest cluster's weight is negligible relative to the heaviest, split the heaviest cluster and share its weight evenly between the two.
if (minWeight < 1e-8 * maxWeight) {
logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
val weight = (maxWeight + minWeight) / 2.0
clusterWeights(largest) = weight
clusterWeights(smallest) = weight
val largestClusterCenter = clusterCenters(largest)
val smallestClusterCenter = clusterCenters(smallest)
var j = 0
while (j < dim) {
val x = largestClusterCenter(j)
val p = 1e-14 * math.max(math.abs(x), 1.0)
largestClusterCenter.asBreeze(j) = x + p
smallestClusterCenter.asBreeze(j) = x - p
j += 1
}
}
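//The ±p perturbation is at most 1e-14 of each coordinate's magnitude: both
//centers restart from (nearly) the largest cluster's position, and the tiny
//offset lets subsequent batches pull the two copies apart.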
//Build and return a new StreamingKMeansModel from the updated clusterCenters and clusterWeights
new StreamingKMeansModel(clusterCenters, clusterWeights)
}
}
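To make the per-cluster arithmetic easy to experiment with outside Spark, here is a minimal, self-contained sketch of the update rule (hypothetical names, plain arrays instead of MLlib vectors; the weight passed in is assumed to be already discounted by a):
{{{
// c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t], expressed the same
// way as in update: scale the old center by (1 - lambda), then add
// (lambda / m_t) times the batch sum.
def updateCenter(
    center: Array[Double],    // c_t
    weight: Double,           // n_t * a (already discounted)
    batchSum: Array[Double],  // x_t * m_t
    batchCount: Long          // m_t
): (Array[Double], Double) = {
  val updatedWeight = weight + batchCount                   // n_t * a + m_t
  val lambda = batchCount / math.max(updatedWeight, 1e-16)  // m_t / (n_t * a + m_t)
  val newCenter = center.zip(batchSum).map { case (c, s) =>
    (1.0 - lambda) * c + (lambda / batchCount) * s
  }
  (newCenter, updatedWeight)
}

// updateCenter(Array(0.0), 9.0, Array(10.0), 5L) == (Array(0.714...), 14.0)
}}}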