8.SparkMLlib GMM高斯混合模型算法
8.1 GMM高斯混合模型算法
在MLlib的聚类算法中,高斯混合模型算法也是一种很重要的聚类算法,其基于单高斯模型,而两者的数学基础都是高斯分布。
在统计学中,若随机变量
X
X
X服从数学期望为
μ
μ
μ、方差为
σ
2
\sigma^2
σ2的高斯分布,则记为
N
(
μ
,
σ
2
)
N(μ,\sigma^2)
N(μ,σ2)。数学表达式如下:
f ( x ) = 1 2 π σ exp ( − ( x − μ ) 2 2 σ 2 ) f ( x ) = \frac { 1 } { \sqrt { 2 \pi } \sigma } \operatorname { exp } ( - \frac { ( x - \mu ) ^ { 2 } } { 2 \sigma ^ { 2 } } ) f(x)=2πσ1exp(−2σ2(x−μ)2)
高斯分布的期望值
μ
μ
μ和标准差
σ
\sigma
σ分别决定了分布的位置和幅度。而最为我们熟知的就是期望值为0,标准差为1的标准高斯分布——正态分布。
以高斯分布为基础的单高斯分布聚类模型,原理是根据已有数据建立一个分布模型,再向模型中代入样本数据计算其值,设定一个阈值范围,如果计算值在阈值以内,则判定该数据与高斯分布相匹配,否则认定该数据不属于该高斯模型的聚类,考虑到实际应用中,特征维度是多维的,多维高斯分布公式如下,
x
x
x为样本数据:
F ( x , μ , σ ) = 1 2 π σ exp ( − ( x − μ ) n − 1 2 σ 2 ) F ( x , \mu , \sigma ) = \frac { 1 } { \sqrt { 2 \pi } \sigma } \operatorname { exp } ( - \frac { ( x - \mu ) ^ { n - 1 } } { 2 \sigma ^ { 2 } } ) F(x,μ,σ)=2πσ1exp(−2σ2(x−μ)n−1)
混合高斯模型GMM是在单高斯模型基础上发展得到的,是一种概率式的聚类方法,属于生成式模型,它假设所有的数据样本都是由某一个给定参数的 多元高斯分布 所生成的。可以有效解决单高斯模型聚合混合数据不理想的问题。GMM的原理就是对于任何样本数据来说,都可以使用多个高斯分布模型表示,其数学表达式如下:
Pr ( x ) = ∑ π F ( x , μ , σ ) \operatorname { Pr } ( x ) = \sum \pi F ( x , \mu , \sigma ) Pr(x)=∑πF(x,μ,σ)
对于此表达式,我们要做的就是使用样本数据,通过极大似然估计法训练获得模型参数。可以通过增加GMM中模型的数量,任意的逼近任何连续的概率密度分布函数,而这个数量值的选择就是训练模型好坏的关键。
8.2 算法源码分析
(1)高斯混合伴生对象:object GaussianMixture,需要设置的参数有:
- k:聚类的数目
- d:特征数目
当k不是很小且d大于25时,对多元高斯的计算需要使用启发式分布。
(2)高斯混合类:class GaussianMixture,抽象GMM的超参数并进行训练,创建此类,使用set方法设置参数,调用fit方法训练一个GMM模型,需要设置的参数如下:
- K——聚类数目,默认为2
- maxIter——最大迭代次数,默认为100
- seed——随机数种子,默认为随机Long值
- Tol——对数似然函数收敛阈值,默认为0.01
(3)高斯混合模型:GaussianMixtureModel,所含参数主要是:
- Weights—— Array[Double]每个多元高斯分布所占权重
- gaussians:——Array[MultivariateGaussian]一个GMM由多个多元高斯分布组成含有predict方法预测样本
还有函数computeSoftAssignments,给定一个新的样本点,计算它属于各个聚类的概率
(4)模型训练
run方法:数据和参数准备好后用来训练模型;
updateWeightsAndGaussians:更新高斯模型的权重
8.3应用实战
8.3.1 数据说明
本次实战使用的数据是iris数据集,与之前教程所用的一样。
8.3.2 代码详解
//导入所需的包文件
import org.apache.spark.ml.clustering.{GaussianMixture,GaussianMixtureModel}
import org.apache.spark.ml.linalg.Vectors
//为后续RDD隐式转换做准备
import spark.implicits._
//定义case class作为生成的DataFrame中每一个数据样本的数据类型
case class model_instance (features: Vector)
输出结果:
defined class model_instance
//数据读入RDD中,并转换为DataFrame
val rawData = sc.textFile("/mnt/hgfs/thunder-download/MLlib_rep/data/iris.txt")
输出结果:
rawData: org.apache.spark.rdd.RDD[String] = iris.csv MapPartitionsRDD[48] at textFile at <console>:33
//过滤类标签
val df = rawData.map(line =>
| { model_instance( Vectors.dense(line.split(",").filter(p =>
p.matches("\\d*(\\.?)\\d*"))
| .map(_.toDouble)) )}).toDF()
//创建GaussianMixture类,设置训练所需参数,聚类数目取3
val gm = new GaussianMixture().setK(3)
|.setPredictionCol("Prediction")
|.setProbabilityCol("Probability")
输出结果:
gm: org.apache.spark.ml.clustering.GaussianMixture =
GaussianMixture_53916e2247ae
val gmm = gm.fit(df)
输出结果:
gmm: org.apache.spark.ml.clustering.GaussianMixtureModel =
GaussianMixture_53916e2247ae
这里设置Probability,相比于KMeans聚类,可以得到样本属于每个聚类的概率
//使用transform方法处理数据集并输出,可以看到每个样本的预测聚类
//和其概率分布向量
val result = gmm.transform(df)
result.show(150, false)
+-----------------+----------+------------------------------------------------------------------+
|features |Prediction|Probability
|
+-----------------+----------+------------------------------------------------------------------+
|[5.1,3.5,1.4,0.2]|0 |[0.9999999999999951,4.682229962936943E-
17,4.868372929920407E-15] |
|.................|.. |................................................................ |
|[5.6,2.8,4.9,2.0]|1 |[8.920203149708086E-
16,0.5988576194515217,0.4011423805484774] |
|.................|.. |................................................................ |
|[6.3,2.7,4.9,1.8]|2 |[5.703158630226758E-
16,0.022033640207248576,0.9779663597927509] |
+-----------------+----------+------------------------------------------------------------------+
//查看模型的相关参数,即各个混合高斯分布的参数,这里主要是各个混合成分//的权重和混合成分的均值向量和协方差矩阵
for (i <- 0 until gmm.getK) {
| println("Component %d : weight is %f \n mu vector is %s \n sigma matrix is %s" format
| (i, gmm.weights(i), gmm.gaussians(i).mean, gmm.gaussians(i).cov))
| }
输出结果如下:
Component 0 : weight is 0.333333
mu vector is
[5.006000336585284,3.41800074359835,1.4640001090120234,0.2439999627867791]
sigma matrix is 0.12176391071215485 0.09829168918600302 0.01581595534223468 0.01033602571352466
0.09829168918600302 0.14227526345684152 0.011447885703674401 0.01120804907975396
0.01581595534223468 0.011447885703674401 0.02950400173292353 0.005584009823879005
0.01033602571352466 0.01120804907975396 0.005584009823879005 0.01126400540784641
Component 1 : weight is 0.158358
mu vector is
[6.683368733405807,2.86961545411428,5.6462886220107515,2.005673427136211]
sigma matrix is 0.49328013505428253 0.050374713498113975 0.3573203540815462 0.050018569392196975
0.050374713498113975 0.04009423452907058 0.00416971505937197 0.02000523766170409
0.3573203540815462 0.00416971505937197 0.33772537665488306 0.017006917604832562
0.050018569392196975 0.02000523766170409 0.017006917604832562 0.06935869650451881
Component 2 : weight is 0.508309
mu vector is
[6.130726266791161,2.872742630634873,4.675369349848198,1.5732931362538298]
sigma matrix is 0.34423978263401117 0.14332952213838432 0.3498831855148551 0.1447023418962832
0.14332952213838432 0.13127254135549662 0.1848327285944271 0.09799971374720898
0.3498831855148551 0.1848327285944271 0.5558476131836437 0.2698797562441122
0.1447023418962832 0.09799971374720898 0.2698797562441122 0.16825697031717957