Gaussian Mixture Model
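A Gaussian Mixture Model (GMM) is a model based on probability densities. It assumes the data points are generated by K normal (Gaussian) distributions, each with its own mean and covariance matrix. The proportion of points drawn from each Gaussian is determined by a prior (the mixing weights). The biggest difference from k-means clustering is that k-means assigns every data point to exactly one cluster, whereas a GMM gives the probability that each point belongs to each cluster. GMM clustering is therefore a form of soft clustering.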
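To make the soft versus hard distinction concrete, here is a minimal 1-D sketch assuming a two-component mixture with hypothetical weights, means and variances. The names GmmSoftAssignment, responsibilities and hardAssignment are illustrative, not Spark API. It computes each component's posterior probability w_k N(x | mu_k, var_k) / sum_j w_j N(x | mu_j, var_j), which is what a GMM reports for a point, and then the single label that a k-means-style hard assignment would keep.

```java
import java.util.Arrays;

/**
 * Minimal 1-D sketch of GMM soft assignment (posterior responsibilities).
 * Class name, method names and parameter values are hypothetical examples.
 */
class GmmSoftAssignment {

    // Density of the univariate normal distribution N(mean, variance) at x.
    static double normalPdf(double x, double mean, double variance) {
        double diff = x - mean;
        return Math.exp(-diff * diff / (2.0 * variance)) / Math.sqrt(2.0 * Math.PI * variance);
    }

    // Soft assignment: r_k = w_k * N(x | mu_k, var_k) / sum_j w_j * N(x | mu_j, var_j).
    static double[] responsibilities(double x, double[] weights, double[] means, double[] variances) {
        double[] r = new double[weights.length];
        double total = 0.0;
        for (int k = 0; k < weights.length; k++) {
            r[k] = weights[k] * normalPdf(x, means[k], variances[k]);
            total += r[k];
        }
        for (int k = 0; k < r.length; k++) {
            r[k] /= total; // normalize so the probabilities sum to 1
        }
        return r;
    }

    // Hard assignment (k-means style): keep only the most probable component.
    static int hardAssignment(double[] r) {
        int best = 0;
        for (int k = 1; k < r.length; k++) {
            if (r[k] > r[best]) {
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] weights = {0.5, 0.5};   // hypothetical mixing weights (the prior)
        double[] means = {0.0, 4.0};     // hypothetical component means
        double[] variances = {1.0, 1.0}; // hypothetical component variances
        for (double x : new double[] {1.0, 2.0, 3.0}) {
            double[] r = responsibilities(x, weights, means, variances);
            System.out.println("x=" + x + " soft=" + Arrays.toString(r)
                    + " hard=" + hardAssignment(r));
        }
    }
}
```

A point halfway between the two means gets probability 0.5 from each component, information that a hard label throws away.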
For an introduction to how GMMs work, I recommend two blog posts:
http://blog.csdn.net/sunanger_wang/article/details/885276
http://blog.sina.com.cn/s/blog_54d460e40101ec00.html
Peter Flach's book Machine Learning also covers GMMs in its chapter on probabilistic models.
GMM in Spark MLlib
I used Spark 1.6.2 with the Java API.
The code is as follows:
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.clustering.GaussianMixture;
import org.apache.spark.mllib.clustering.GaussianMixtureModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;

public class MyGaussian {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("GMM").setMaster("local");
        // Override the memory size Spark detects so a small local run passes its minimum-memory check
        conf.set("spark.testing.memory", "2147480000");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Read the file and parse each line into a vector
        String path = "/home/quincy1994/桌面/test.txt";
        JavaRDD<String> data = sc.textFile(path);
        JavaRDD<Vector> parsedData = data.map(new DataTovector());
        parsedData.cache(); // cache the training set, which training scans repeatedly

        // Train a Gaussian mixture model with K = 4 components
        GaussianMixtureModel gmm = new GaussianMixture().setK(4).run(parsedData.rdd());

        // Print the soft assignment: one probability per component for each point
        RDD<Vector> points = parsedData.rdd();
        JavaRDD<double[]> resultRDD = gmm.predictSoft(points).toJavaRDD();
        List<double[]> result = resultRDD.collect();
        int count = 0;
        for (double[] one : result) {
            System.out.print("point " + count + ":");
            for (double pro : one) {
                System.out.print(pro + " ");
            }
            System.out.println();
            count += 1;
        }
        sc.stop();
    }

    // Parses a whitespace-separated line such as "90.0 90.0" into a dense vector
    static class DataTovector implements Function<String, Vector> {
        @Override
        public Vector call(String s) throws Exception {
            String[] array = s.trim().split("\\s+");
            double[] values = new double[array.length];
            for (int i = 0; i < array.length; i++) {
                values[i] = Double.parseDouble(array[i]);
            }
            return Vectors.dense(values);
        }
    }

    // Hard-assignment helper (not called in main): pairs each vector with its most likely
    // component, e.g. parsedData.mapToPair(new Predict(gmm))
    static class Predict implements PairFunction<Vector, Vector, Integer> {
        private final GaussianMixtureModel gmm;

        public Predict(GaussianMixtureModel gmm) {
            this.gmm = gmm;
        }

        @Override
        public Tuple2<Vector, Integer> call(Vector v) throws Exception {
            int tag = gmm.predict(v);
            return new Tuple2<Vector, Integer>(v, tag);
        }
    }
}
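Spark's documentation describes GaussianMixture as fitting the model with expectation maximization (EM). To show what one EM iteration does, here is a minimal 1-D sketch. It is not Spark's implementation: Spark works with multivariate Gaussians and full covariance matrices, and stops according to setConvergenceTol and setMaxIterations. The class name GmmEmStep, the minVariance floor and the sample data are hypothetical.

```java
import java.util.Arrays;

/**
 * Minimal 1-D sketch of expectation maximization for a GMM.
 * Not Spark's code; names, data and the variance floor are hypothetical.
 */
class GmmEmStep {

    static double normalPdf(double x, double mean, double variance) {
        double diff = x - mean;
        return Math.exp(-diff * diff / (2.0 * variance)) / Math.sqrt(2.0 * Math.PI * variance);
    }

    // Runs one EM iteration, updating weights, means and variances in place.
    // minVariance stops a component from collapsing onto a single point.
    static void emStep(double[] data, double[] weights, double[] means,
                       double[] variances, double minVariance) {
        int n = data.length;
        int k = weights.length;
        double[][] resp = new double[n][k];

        // E-step: resp[i][j] = P(component j | data[i]) under the current parameters.
        for (int i = 0; i < n; i++) {
            double total = 0.0;
            for (int j = 0; j < k; j++) {
                resp[i][j] = weights[j] * normalPdf(data[i], means[j], variances[j]);
                total += resp[i][j];
            }
            for (int j = 0; j < k; j++) {
                resp[i][j] /= total;
            }
        }

        // M-step: re-estimate each component from responsibility-weighted data.
        for (int j = 0; j < k; j++) {
            double nj = 0.0;
            double sum = 0.0;
            for (int i = 0; i < n; i++) {
                nj += resp[i][j];
                sum += resp[i][j] * data[i];
            }
            if (nj == 0.0) {
                continue; // component received no mass; leave it unchanged
            }
            double mean = sum / nj;
            double sq = 0.0;
            for (int i = 0; i < n; i++) {
                double d = data[i] - mean;
                sq += resp[i][j] * d * d;
            }
            weights[j] = nj / n;
            means[j] = mean;
            variances[j] = Math.max(sq / nj, minVariance);
        }
    }

    public static void main(String[] args) {
        double[] data = {1.0, 1.2, 0.8, 5.0, 5.2, 4.8}; // hypothetical 1-D sample
        double[] weights = {0.5, 0.5};
        double[] means = {0.0, 6.0};                     // hypothetical initial guesses
        double[] variances = {1.0, 1.0};
        for (int iter = 0; iter < 20; iter++) {
            emStep(data, weights, means, variances, 1e-6);
        }
        System.out.println("weights=" + Arrays.toString(weights)
                + " means=" + Arrays.toString(means)
                + " variances=" + Arrays.toString(variances));
    }
}
```

The E-step is exactly the soft assignment that predictSoft reports; the M-step turns those probabilities back into weights, means and variances.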
The test data is as follows:
90.0 90.0
90.1 90.1
90.2 90.1
90.2 90.2
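With K set to 4, training produced the output below. Each row lists, for one point, the probability that it belongs to each of the 4 components: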
point 0:0.5 9.282585397132313E-20 0.5 9.282585397132313E-20
point 1:0.5 9.281685478536645E-20 0.5 9.281685478536645E-20
point 2:4.292739440223956E-22 0.020398093383873512 4.292739440223956E-22 0.9796019066161266
point 3:2.4130589663026416E-5 2.4443552917481E-22 2.4130589663026416E-5 0.999951738820674
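Each row above should sum to 1 up to rounding, and a hard label can be read off as the index of the largest probability. As far as I know, Spark's GaussianMixtureModel.predict does the same (it takes the index of the maximum value in the predictSoft row), so ties such as points 0 and 1 (0.5 for components 0 and 2) go to the lower index. The small sketch below, using a hypothetical class SoftToHard, checks this on the four rows copied from the output.

```java
/**
 * Turns the predictSoft rows printed above into hard labels.
 * SoftToHard is a hypothetical helper, not part of Spark.
 */
class SoftToHard {

    // Rows copied verbatim from the output above (K = 4).
    static final double[][] ROWS = {
        {0.5, 9.282585397132313E-20, 0.5, 9.282585397132313E-20},
        {0.5, 9.281685478536645E-20, 0.5, 9.281685478536645E-20},
        {4.292739440223956E-22, 0.020398093383873512, 4.292739440223956E-22, 0.9796019066161266},
        {2.4130589663026416E-5, 2.4443552917481E-22, 2.4130589663026416E-5, 0.999951738820674}
    };

    // Index of the largest probability; ties go to the lowest index.
    static int argmax(double[] probs) {
        int best = 0;
        for (int k = 1; k < probs.length; k++) {
            if (probs[k] > probs[best]) {
                best = k;
            }
        }
        return best;
    }

    // Sum of one row; should be 1 up to floating-point rounding.
    static double sum(double[] probs) {
        double total = 0.0;
        for (double p : probs) {
            total += p;
        }
        return total;
    }

    public static void main(String[] args) {
        for (int i = 0; i < ROWS.length; i++) {
            System.out.printf("point %d: label=%d, sum=%.6f%n", i, argmax(ROWS[i]), sum(ROWS[i]));
        }
    }
}
```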