在完成sparkMllib GMM算法例子之前需要知道几个概念。1、高斯分布、2、多维高斯分布。3、高斯混合分布。4、协方差
GMM称为混合高斯分布,它在单高斯分布(又称正太分布,一维正太分布)的基础上针对多元变量发展出来的。(以下参考了百度词条内容)
1)单高斯分布公式:
,该公式的推导以及意义大家可以自行百度,这里只讲一下各个参数在公式中的意义:
μ是正态分布的位置参数,描述正态分布的集中趋势位置。概率规律为取与μ邻近的值的概率大,而取离μ越远的值的概率越小。正态分布以X=μ为对称轴,左右完全对称。正态分布的期望、均数、中位数、众数相同,均等于μ。
σ描述正态分布资料数据分布的离散程度,σ越大,数据分布越分散,σ越小,数据分布越集中。也称为是正态分布的形状参数,σ越大,曲线越扁平,反之,σ越小,曲线越瘦高。
2)多维单高斯分布公式:
由上面的定义可知,多维单高斯分布的方差其实是协方差矩阵。
3)高斯混合分布:就是多个高斯分布(可能是单高斯也可能是多维高斯)的组合。下面是李航老师在《统计学习方法》
由上图公式可知,高斯混合分布多了一个参数,该参数就是每单高斯分布在高斯混合分布里面的权重。
4)协方差矩阵的含义可以参考该篇博文:http://blog.csdn.net/yangdashi888/article/details/52397990
sparkMllib GMM算法就是根据一批给定的随机变量,每个随机变量肯能是一维的,也可能是多维的,然后求出高斯混合分布中的三个参数:
1、a权重。2、μ(如果是多维就是一个数组)3、方差(一维)/协方差矩阵(多维)
以下是sparkMllib GMM的例子。
1、数据gmm_data.txt中是二维数据,部分数据展示如下:
2.59470454e+00 2.12298217e+00
1.15807024e+00 -1.46498723e-01
2.46206638e+00 6.19556894e-01
-5.54845070e-01 -7.24700066e-01
-3.23111426e+00 -1.42579084e+00
2、数据gmm_data1.txt中是二维数据,部分数据展示如下:
2.59470454e+00
2.12298217e+00
1.15807024e+00
-1.46498723e-01
案例代码如下:
package spark;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
// $example on$
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.clustering.GaussianMixture;
import org.apache.spark.mllib.clustering.GaussianMixtureModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
public class JavaGaussianMixtureExample {
public static void main(String[] args) {
Logger logger = Logger.getLogger(JavaGaussianMixtureExample.class);
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF);
SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("JavaGaussianMixtureExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
String path = "F:/spark-2.1.0-bin-hadoop2.6/data/mllib/gmm_data1.txt";
JavaRDD<String> data = jsc.textFile(path);
JavaRDD<Vector> parsedData = data.map(f->{
return Vectors.dense(Double.parseDouble(f.trim()));
});
parsedData.cache();
/**
* k指定了高斯混合分布中的高斯分布个数。
*/
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
for (int j = 0; j < gmm.k(); j++) {
System.out.printf("一维混合高斯分布得到的数据如下:\nweight=%f\nmu=%s\nsigma=\n%s\n", gmm.weights()[j], gmm.gaussians()[j].mu(),
gmm.gaussians()[j].sigma());
}
logger.info("split line =====================================");
String path2 = "F:/spark-2.1.0-bin-hadoop2.6/data/mllib/gmm_data.txt";
JavaRDD<String> data2 = jsc.textFile(path2);
JavaRDD<Vector> parsedData2 = data2.map(s -> {
String[] sarray = s.trim().split(" ");
double[] values = new double[sarray.length];
for (int i = 0; i < sarray.length; i++) {
values[i] = Double.parseDouble(sarray[i]);
}
return Vectors.dense(values);
});
parsedData2.cache();
GaussianMixtureModel gmm2 = new GaussianMixture().setK(2).run(parsedData2.rdd());
for (int j = 0; j < gmm2.k(); j++) {
System.out.printf("二维混合高斯分布得到的数据如下:\nweight=%f\nmu=%s\nsigma=\n%s\n", gmm2.weights()[j], gmm2.gaussians()[j].mu(),
gmm2.gaussians()[j].sigma());
}
jsc.stop();
}
}
执行结果如下: