StreamingKMeans核心源码解读 流式KMeans核心源码解读

StreamingKMeans核心源码解读 流式KMeans核心源码解读

大家好,我是一拳就能打爆A柱的硬核男人

之前给大家翻过流式算法的继承树,而且对于每一部分组件的内容、职责都有了一点了解,其实Spark流式算法的大致结构都差不多,所以这里也不给大家翻继承树了,直接上核心部分的代码一行行的分析。接下来我会先介绍方法入口,方便大家打开IDE跟着博客一起看,同时我会以行号加粗的方式标志代码,下方配上分析,希望各位能习惯。(建议一起打开IDE源码交叉看,当然我也会把代码贴出来。)

1、 方法入口

在Spark官网有关于StreamingKMeans的介绍,其中附上了小案例:

import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setAppName("StreamingKMeansExample")
val ssc = new StreamingContext(conf, Seconds(args(2).toLong))

val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
// new streamingKMeans对象
val model = new StreamingKMeans()
  .setK(args(3).toInt)
  .setDecayFactor(1.0)
  .setRandomCenters(args(4).toInt, 0.0)
// 训练
model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()

可以看到model.trainOn就是我们要找的入口,trainOn方法如下:

def trainOn(data: DStream[Vector]): Unit = {
    assertInitialized()
    data.foreachRDD { (rdd, time) =>
        model = model.update(rdd, decayFactor, timeUnit)
    }
}

可以看到DStream使用foreachRDD方法,对每个时间段产生的RDD做更新操作(model.update),最后就是一个新的model出现在Driver端了!所以update方法就是核心代码位置:

def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {

    // find nearest cluster to each point
    val closest = data.map(point => (this.predict(point), (point, 1L)))

    // get sums and counts for updating each cluster
    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
        BLAS.axpy(1.0, p2._1, p1._1)
        (p1._1, p1._2 + p2._2)
    }
    val dim = clusterCenters(0).size

    val pointStats: Array[(Int, (Vector, Long))] = closest
    .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
    .collect()

    val discount = timeUnit match {
        case StreamingKMeans.BATCHES => decayFactor
        case StreamingKMeans.POINTS =>
        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
            n
        }.sum
        math.pow(decayFactor, numNewPoints)
    }

    // apply discount to weights
    BLAS.scal(discount, Vectors.dense(clusterWeights))

    // implement update rule
    pointStats.foreach { case (label, (sum, count)) =>
        val centroid = clusterCenters(label)

        val updatedWeight = clusterWeights(label) + count
        val lambda = count / math.max(updatedWeight, 1e-16)

        clusterWeights(label) = updatedWeight
        BLAS.scal(1.0 - lambda, centroid)
        BLAS.axpy(lambda / count, sum, centroid)

        // display the updated cluster centers
        val display = clusterCenters(label).size match {
            case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
            case _ => centroid.toArray.mkString("[", ",", "]")
        }

        logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
    }

    // Check whether the smallest cluster is dying. If so, split the largest cluster.
    val weightsWithIndex = clusterWeights.view.zipWithIndex
    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
    if (minWeight < 1e-8 * maxWeight) {
        logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
        val weight = (maxWeight + minWeight) / 2.0
        clusterWeights(largest) = weight
        clusterWeights(smallest) = weight
        val largestClusterCenter = clusterCenters(largest)
        val smallestClusterCenter = clusterCenters(smallest)
        var j = 0
        while (j < dim) {
            val x = largestClusterCenter(j)
            val p = 1e-14 * math.max(math.abs(x), 1.0)
            largestClusterCenter.asBreeze(j) = x + p
            smallestClusterCenter.asBreeze(j) = x - p
            j += 1
        }
    }

    new StreamingKMeansModel(clusterCenters, clusterWeights)
}

找到代码位置后先不着急看,先把官方对StreamingKmeans的思想介绍看一下:

c t + 1 = c t n t α + x t m t n t α + m t             ( 1 ) n t + 1 = n t + m t                       ( 2 ) c_{t+1}=\frac{c_tn_tα+x_tm_t}{n_tα+m_t} \ \ \ \ \ \ \ \ \ \ \ (1) \\n_{t+1}=n_t+m_t \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) ct+1=ntα+mtctntα+xtmt           (1)nt+1=nt+mt                     (2)

Where ct is the previous center for the cluster, nt is the number of points assigned to the cluster thus far, xt is the new cluster center from the current batch, and mt is the number of points added to the cluster in the current batch. The decay factor α can be used to ignore the past: with α=1 all data will be used from the beginning; with α=0 only the most recent data will be used. This is analogous to an exponentially-weighted moving average.

翻译:Ct是旧簇心,Nt是目前为止被分配到该簇(Ct)的点的个数,Xt是新batch的新簇心,Mt是新batch中被添加到该簇(Ct)的点的个数。遗忘因子α可以用来遗忘过去(减少过去样本点的影响):当α=1时过去所有的样本点的影响都会用到,当α=0时只有最新的数据的影响会用到。这类似于指数加权移动平均值。

2、 逐行解读

在开始分析前,还要弄清楚update方法的参数:

def update(data: RDD[Vector], decayFactor: Double, timeUnit: String)
// data: RDD格式,内装Vector向量,注意与LabeledPoint区分,此处没有label
// decayFactor: 遗忘因子,可以给定值,当然也有构造函数可以设定
// timeUnit: 时间单元,可以是一批数据(StreamingKMeans.BATCHES)也可以是一个样本(StreamingKMeans.POINTS),用于区分到达数据的量

搞清楚参数后,接下来开始逐行分析。

81行

先用之前训练好的model来预测这批刚到达的数据,给他们分一下类,得到的格式是RDD[(Int, (Vector , Long))],每一条分别代表:(改点被分到的簇,(点向量,1))。还记得wordcount吗?后面的1L是不是就类似wordcount计数,事实上也是这样的,待会要分类累加。

85 - 88行

这是定义了一个聚类的方法,接受两个点p1,p2,点的格式都是(Vector,Long) ,然后将点的向量相加,最后返回两个点的向量值之和,两个点的Long值。从这里可以预测该方法是在定义聚类的规则

89行

取出维度长度

91 - 93行

81行的结果做聚类(aggregateByKey),aggregateByKey方法的使用大家可以去这篇博客看一看,我引用它的图来给各位解释一下:

def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
                                              combOp: (U, U) => U): RDD[(K, U)]
// (zeroValue: U) 零值 也是初始值
// seqOp: (U, V) => U 方法一,在分区内聚合使用的方法
// combOp: (U, U) => U 方法二,在分区间聚合使用的方法

根据上图,两个分区中有key相同的样本,分区内聚合的规则是将key相同的value取最大值,0值(初始值)为3。

对于第一个分区:(1, 1)、(1, 2) -> (1, 3) ;(2, 1) -> (2, 3) (注:对于key=1:3与1相比取3,3与2相比取3,对于key=2:1与3相比取3

对于第二个分区:(2, 3)、(2, 4) -> (2, 4) ;(1, 7) -> (1, 7)

分区间的聚合规则是将相同key的value相加:

所以shuffle,得到:(1, 3)、(1, 7) -> (1, 10) ; (2, 3)、(2, 4) -> (2 , 7)

**根据这个方法,可以将之前分类的点,根据不同的簇计算出当前批次样本在各个簇的特征总和、被分到该簇的样本数。**还记得Spark对流式算法的两个公式吗?现在正在计算出指标,此时Mt和Xt已经出来了!此时经过聚合加collect,数据已经集中在Driver端了。

可能这部分讲的有点乱,建议各位去看懂aggregateByKey方法然后仔细看85 - 88行的定义,最后再看91 - 93行

95 - 102行

这里用到了timeUnit,前面也讲过这个是用来判断本批次数据是批还是条。对于批数据,直接返回遗忘因子decayFactor(也就是公式中的α),对于一条数据,计算出所有新增点的总数,返回α^总数。

105行

BLAS.scal(discount, Vectors.dense(clusterWeights))
// clusterWeights指的是各个簇的样本数,对你没看错,虽然带了个weights但是表示的是簇的样本数!
// 还记得前面Spark对参数的介绍中有α吗?Spark使用遗忘因子α对过去的部分样本遗忘,这就是在打折,也就是现在做的事 clusterWeights = discount * clusterWeights

各位也可能会找到《【技术分享】流式k-means算法》这篇文章,其中针对这个参数的介绍说成是权重,其实是不准确的。至于为什么,大家在下面的具体计算中可以看到。

108 - 125行

pointStats.foreach { case (label, (sum, count)) => // (簇,(特征总和,样本数))
    val centroid = clusterCenters(label)

    val updatedWeight = clusterWeights(label) + count
    val lambda = count / math.max(updatedWeight, 1e-16)

    clusterWeights(label) = updatedWeight
    BLAS.scal(1.0 - lambda, centroid)
    BLAS.axpy(lambda / count, sum, centroid)

   .....
}

计算的部分保留,打log的去掉了。pointStats就是91 - 93行计算的结果,表示按簇分类,新来的数据特征总和跟样本数。

操作如下:

  • 1、取出簇心向量(之前训练出来的,通过label取出来,这个clusterCenters是由kmeansModel维护的)
  • 2、 取出每个簇心对应的样本数,加上新来的分到对应簇的样本个数。
  • 3、 求λ
  • 4、 更新簇的样本数
  • 5、 簇心向量 = λ * 簇心向量 ,更新簇心向量
  • 6、 簇心 = λ/该簇样本数 * 该簇样本特征总值

到这里关于整个model的权重计算已经结束了,下面关于死亡簇和分割最大簇的部分不是本篇重点,感兴趣的同学可以自己观看。那么现在还是先来验证为什么clusterWeights是样本数而不是什么权重。

先看下面这段代码:

BLAS.scal(discount, Vectors.dense(clusterWeights))

val updatedWeight = clusterWeights(label) + count
val lambda = count / math.max(updatedWeight, 1e-16)

clusterWeights(label) = updatedWeight

以上就是update方法中用到clusterWeights的关键地方,clusterWeights是一个val clusterWeights: Array[Double],已知label是簇id,所以簇数应当是远远小于样本特征数的,这一点就可以怀疑它不是权重。而且clusterWeights(label)取出值与count相加后得到updatedWeights被用来计算λ,最后又返回去了=。= count是簇id为label的样本个数,所以说取出来旧的簇样本数更新归类为label簇的样本个数是理所当然的。

加上前面的α遗忘因子遗忘了部分旧的样本点,所以这样的操作是符合Spark的设计思想的。

2.1 复盘公式

因为整个程序变量有点多,而且计算方式感觉与Spark官方给的公式有出入,所以对程序复盘,希望能还原出公式。

c t + 1 = c t n t α + x t m t n t α + m t             ( 1 ) n t + 1 = n t + m t                       ( 2 ) c_{t+1}=\frac{c_tn_tα+x_tm_t}{n_tα+m_t} \ \ \ \ \ \ \ \ \ \ \ (1) \\n_{t+1}=n_t+m_t \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) ct+1=ntα+mtctntα+xtmt           (1)nt+1=nt+mt                     (2)
首先找到对应的变量,在91行中已经可以得到特征向量综合跟个数(按簇分类),如下图:

在这里插入图片描述

所以pointStats._ 2._ 1 = Xt,pointStats._ 2._ 2 = Mt。

在clusterWeights中可以得到各个簇的样本总数Nt。

在clusterCenters中可以得到各个簇的簇心向量Ct。

所以所有的元素都得到了!接下来看看是如何组装的,首先看最简单的第二个公式:
n t + 1 = n t + m t                       ( 2 ) n_{t+1}=n_t+m_t \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) nt+1=nt+mt                     (2)
对应的代码就是:

val updatedWeight = clusterWeights(label) + count
clusterWeights(label) = updatedWeight

没错,取出来,加一下,返回去。

第一个公式有点复杂,一部分一部分来,首先确定α的位置:

val discount = timeUnit match {
    case StreamingKMeans.BATCHES => decayFactor
    case StreamingKMeans.POINTS =>
    val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
        n
    }.sum
    math.pow(decayFactor, numNewPoints)
}
BLAS.scal(discount, Vectors.dense(clusterWeights))

这里的discount就是α,而且对clusterWeights做了遗忘,所以现在这个Array中全都是原来的nt遗忘了一部分样本,即nt*α。接下来就不会再出现α,因为都融合入clusterWeights了。

重点在下面这个遍历,取出每一个簇,对应的特征总和(XtMt),分到簇的样本数(Mt):

pointStats.foreach { case (label, (sum/*(XtMt)*/, count/*(Mt)*/)) =>
    val centroid/*(Ct)*/ = clusterCenters(label)

    val updatedWeight/*(Nt*α + Mt)*/ = clusterWeights(label) + count
    val lambda/*(Mt/(Nt*α+Mt))*/ = count / math.max(updatedWeight, 1e-16)

    clusterWeights(label) = updatedWeight
    BLAS.scal(1.0 - lambda, centroid)// Ct =(1-λ)Ct
    BLAS.axpy(lambda / count, sum, centroid) // Ct += λ/Mt * XtMt

   ....
}

再来一遍:
下 面 的 运 算 一 步 步 算 应 该 看 得 清 楚 : s u m : X t ∗ M t c o u n t : M t c e n t r o i d : C t u p d a t e d W e i g h t : N t α + M t l a m b d a : λ = M t N t α + M t B L A S . s c a l ( 1.0 − l a m b d a , c e n t r o i d ) : C t = ( 1 − λ ) C t = ( 1 − M t N t α + M t ) C t B L A S . a x p y ( l a m b d a / c o u n t , s u m , c e n t r o i d ) : C t = C t + λ M t ∗ X t M t = ( 1 − M t N t α + M t ) C t + M t N t α + M t M t ∗ X t M t = N t α + M t − M t N t α + M t C t + X t M t N t α + M t = N t α C t + X t M t N t α + M t 下面的运算一步步算应该看得清楚:\\sum : X_t*M_t \\count : M_t \\centroid : C_t \\updatedWeight : N_tα+M_t \\lambda : λ = \frac{M_t}{N_tα+M_t} \\BLAS.scal(1.0 - lambda, centroid) : C_t = (1-λ)C_t = (1-\frac{M_t}{N_tα+M_t}) C_t \\BLAS.axpy(lambda / count, sum, centroid): C_t = C_t+\frac{λ}{M_t}*X_tM_t \\ = (1-\frac{M_t}{N_tα+M_t})C_t + \frac{\frac{M_t}{N_tα+M_t}}{M_t}*X_tM_t \\= \frac{N_tα+M_t-M_t}{N_tα+M_t}C_t + \frac{X_tM_t}{N_tα+M_t} \\= \frac{N_tαC_t+X_tM_t}{N_tα+M_t} sum:XtMtcount:Mtcentroid:CtupdatedWeight:Ntα+Mtlambda:λ=Ntα+MtMtBLAS.scal(1.0lambda,centroid):Ct=(1λ)Ct=(1Ntα+MtMt)CtBLAS.axpy(lambda/count,sum,centroid):Ct=Ct+MtλXtMt=(1Ntα+MtMt)Ct+MtNtα+MtMtXtMt=Ntα+MtNtα+MtMtCt+Ntα+MtXtMt=Ntα+MtNtαCt+XtMt
所以,可以证明Spark Streaming关于KMeans算法的计算方法是与其介绍相符的。

3、流式KMeans的思想

光是了解其内部流程还不行,主要是要看他的思想,设计者安排这样的公式体现了其对新旧数据的重视程度,最核心的公式就是公式一:
c t + 1 = c t n t α + x t m t n t α + m t             ( 1 ) c_{t+1}=\frac{c_tn_tα+x_tm_t}{n_tα+m_t} \ \ \ \ \ \ \ \ \ \ \ (1) ct+1=ntα+mtctntα+xtmt           (1)
分母通过α遗忘因子将部分历史样本遗忘后,加上新加入本簇的样本点个数。

分子右半部分表示新加入本簇的样本点的向量和,左半部分不带α,表示上一轮截止未被遗忘的所有样本点向量和,带α则表示遗忘一部分样本点后的向量和。

综上,分子是本轮迭代该保留下来的样本点的向量和,分母是该保留下来的样本点的总数。二者相除则表示那么多点的向量的均值。公式的核心部分主要集中在遗忘因子α上,α的加入可以让簇心时刻保持一个可以活动的状态,如果没有α,在累积大量的数据之后簇心受到过往数据的牵扯太大难以移动,从而失去适应能力。如果α太大则完全跳脱,来一批数据就是一个新的模型,所以α作为一个超参数需要认真斟酌。

总结

StreamingKMeans的核心源码就讲解完了,这个算法是Spark流式机器学习中最简单的一个,所以推荐大家从这个算法开始入门。Spark提供了公式给我们参考,并且从分布式的角度来说,分布式在这个算法的应有只有计算样本的分类,计算向量和,就是一个wordcount。所以也降低了阅读的难度。

其实之前我也写过一篇关于流式KMeans的博客,可是当时对各方面的知识理解太浅了,所以很多东西都说错了,这篇文章我经过反复验证应该是没问题的。
这篇文章写得还是略显杂乱,但是我认为该讲到的点都已经讲清楚了,如果大家有不理解的地方可以在博客下方留言。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值