利用java的spark做高斯混合模型聚类

高斯混合模型

高斯混合模型(Gaussian Mixture Model), 简称为GMM,是一个基于概率密度的模型。在这种模型中,数据点是由K个正态分布所生成的,每个正态分布都拥有自己的均值和协方差矩阵,而来自每个高斯分布的数据点的比例有先验r决定。与k-means聚类最大的不同在,k-means的结果是每个数据点都分布到唯一的cluster中,而GMM则给出这些数据点被分配到每个cluster的概率,因此高斯混合模型聚类属于软聚类的一种。

关于高斯混合模型的原理介绍,我推荐两份博客:
http://blog.csdn.net/sunanger_wang/article/details/885276
http://blog.sina.com.cn/s/blog_54d460e40101ec00.html
而Peter Flach的《机器学习》中在概率模型章节中也有谈到高斯混合模型

关于spark中mllib中的GMM

我运用的spark1.6.2版,使用的是JAVA语言
代码如下:

import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.clustering.GaussianMixture;
import org.apache.spark.mllib.clustering.GaussianMixtureModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;


public class MyGaussian {

    public static void main(String[] args){
        SparkConf conf = new SparkConf().setAppName("GMM").setMaster("local");
        conf.set("spark.testing.memory", "2147480000");
        JavaSparkContext sc = new JavaSparkContext(conf);

        //读取文件,加载数据
        String path = "/home/quincy1994/桌面/test.txt";
        JavaRDD<String> data = sc.textFile(path);
        JavaRDD<Vector> parsedData  = data.map(new DataTovector());
        parsedData.cache(); //暂存训练集数据

        //训练混合高斯模型
        GaussianMixtureModel gmm = new GaussianMixture().setK(4).run(parsedData.rdd());
        parsedData.cache();

        //输出结果
        RDD<Vector> points = parsedData.rdd();
        JavaRDD resultRDD = new JavaRDD(gmm.predictSoft(points), null);
        List<double[]> result= resultRDD.collect();
        int count = 0;
        for(double[] one: result){
            System.out.print("point "+ count + ":");
            for(double pro: one){
                System.out.print(pro + " ");
            }
            System.out.println();
            count += 1;
        }
    }

    static class DataTovector implements Function<String, Vector>{
        public Vector call(String s) throws Exception {
            // TODO Auto-generated method stub
            String[] array = s.trim().split(" ");
            double[] values  = new double[array.length];
            for(int i = 0; i<array.length; i++){
                values[i] = Double.parseDouble(array[i]);
            }
            return Vectors.dense(values);
        }
    }

    static class Predict implements PairFunction<Vector, Vector, Integer>{

        GaussianMixtureModel gmm;

        public Predict(GaussianMixtureModel gmm){
            this.gmm = gmm;
        }

        public Tuple2<Vector, Integer> call(Vector v) throws Exception {
            // TODO Auto-generated method stub
            int tag = gmm.predict(v);
            return new Tuple2<Vector, Integer>(v, tag);
        }

    }
}

测试数据如下:

90.0 90.0
90.1 90.1
90.2 90.1
90.2 90.2

当时设置K的值为4,训练出来的结果如下:

point 0:0.5 9.282585397132313E-20 0.5 9.282585397132313E-20
point 1:0.5 9.281685478536645E-20 0.5 9.281685478536645E-20
point 2:4.292739440223956E-22 0.020398093383873512 4.292739440223956E-22 0.9796019066161266
point 3:2.4130589663026416E-5 2.4443552917481E-22 2.4130589663026416E-5 0.999951738820674

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值