Spark KMeans implemented in Java

A complete example of clustering with Spark MLlib's KMeans: load and parse the sample data, train a model with k = 2, count the members of each cluster, and print the cluster centers and the clustering cost (WSSSE). Three ways of counting cluster membership are shown; the first two are left commented out.

package zqr.com;

import org.apache.spark.Accumulator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import scala.Tuple2;

import java.util.List;
import java.util.Map;

public class KmeansSF {

    private static int oneCount = 0;
    private static int twoCount = 0;

    // public static volatile Broadcast<List<String>> broadcastList = null;
    // public static volatile Accumulator<Integer> accumulator = null;

    public static void main(String[] args) {
        deeping();
    }

    public static JavaSparkContext init() {
        SparkConf conf = new SparkConf().setAppName("kmeans-SF").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        sc.setLogLevel("WARN");
        return sc;
    }

    public static void deeping() {
        JavaSparkContext jsc = init();

        // Load and parse the data: one space-separated dense vector per line.
        String path = "/usr/local/spark/data/mllib/kmeans_data.txt";
        JavaRDD<String> data = jsc.textFile(path);
        JavaRDD<Vector> parsedData = data.map(s -> {
            String[] sarray = s.split(" ");
            double[] values = new double[sarray.length];
            for (int i = 0; i < sarray.length; i++) {
                values[i] = Double.parseDouble(sarray[i]);
            }
            return Vectors.dense(values);
        });
        parsedData.cache();

        // Cluster the data into two classes using KMeans.
        int numClusters = 2;
        int numIterations = 20;
        KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);

        // ===================================== method 1
        // Collect the vectors to the driver and count cluster membership there.
        // parsedData.collect().forEach(x -> System.out.println(x + ":" + clusters.predict(x)));
        // List<Vector> list = parsedData.collect();
        // for (Vector x : list) {
        //     int pd = clusters.predict(x);
        //     if (pd == 0) {
        //         oneCount++;
        //     } else {
        //         twoCount++;
        //     }
        // }
        //
        // Predict the cluster of a single new point:
        // int x = clusters.predict(Vectors.dense(new double[]{0.2, 0.3, 0.8}));
        // System.out.println("[0.2,0.3,0.8]:" + x);

        // ===================================== method 2
        // Count inside a transformation. Caution: incrementing the static
        // driver-side counters inside mapToPair only appears to work because
        // the master is "local"; on a real cluster the increments happen in
        // executor JVMs and are lost. Use an accumulator for this pattern.
        // JavaPairRDD<Vector, Integer> combined = parsedData.mapToPair(x -> {
        //     int pd = clusters.predict(x);
        //     if (pd == 0) {
        //         oneCount++;
        //     } else {
        //         twoCount++;
        //     }
        //     return new Tuple2<Vector, Integer>(x, pd);
        // });
        // combined.collect().forEach(System.out::println);
        //
        // int flag = 0;
        // System.out.println("Cluster centers:");
        // for (Vector center : clusters.clusterCenters()) {
        //     if (flag == 0) {
        //         System.out.println("  " + center + " " + oneCount);
        //         flag = 1;
        //     } else {
        //         System.out.println("  " + center + " " + twoCount);
        //     }
        // }

        // ===================================== method 3
        // Key each vector by its predicted cluster, then count by key.
        JavaPairRDD<Integer, Vector> countByCluster = parsedData.mapToPair(x -> {
            int pd = clusters.predict(x);
            return new Tuple2<Integer, Vector>(pd, x);
        });
        Map<Integer, Long> map = countByCluster.countByKey();
        for (Map.Entry<Integer, Long> entry : map.entrySet()) {
            System.out.println("Key = " + entry.getKey() + ", Value = " + entry.getValue());
        }

        System.out.println("Cluster centers:");
        for (Vector center : clusters.clusterCenters()) {
            System.out.println("  " + center);
        }

        // Evaluate the clustering by computing the Within Set Sum of Squared Errors.
        double WSSSE = clusters.computeCost(parsedData.rdd());
        System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

        // Save and load the model:
        // clusters.save(jsc.sc(), "target/org/apache/spark/JavaKMeansExample/KMeansModel");
        // KMeansModel sameModel = KMeansModel.load(jsc.sc(),
        //         "target/org/apache/spark/JavaKMeansExample/KMeansModel");
    }
}
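A note on the input: the path above points at the sample file that ships with a Spark distribution, so adjust it to wherever Spark is installed. Each line holds one space-separated vector. The stock kmeans_data.txt bundled with Spark looks like the following (two well-separated groups, which is why k = 2 is a natural fit; verify against your own distribution):

0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2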
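As noted in the method 2 comments, static driver-side counters are not cluster-safe. A minimal sketch of the same per-cluster count done with Spark's LongAccumulator (available since Spark 2.0); the accumulator names are illustrative and not part of the original example:

import org.apache.spark.util.LongAccumulator;

// Inside deeping(), after the model has been trained:
LongAccumulator cluster0 = jsc.sc().longAccumulator("cluster0");
LongAccumulator cluster1 = jsc.sc().longAccumulator("cluster1");

// foreach is an action, so the accumulator updates are applied reliably
// and are visible on the driver afterwards.
parsedData.foreach(x -> {
    if (clusters.predict(x) == 0) {
        cluster0.add(1);
    } else {
        cluster1.add(1);
    }
});
System.out.println("cluster 0: " + cluster0.value() + ", cluster 1: " + cluster1.value());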
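The example hard-codes k = 2. When the right number of clusters is not known in advance, a common heuristic is to train a model for each candidate k and look for the "elbow" where the WSSSE stops dropping sharply. A sketch reusing parsedData and numIterations from the example above (the range 2 to 6 is an arbitrary choice):

for (int k = 2; k <= 6; k++) {
    KMeansModel model = KMeans.train(parsedData.rdd(), k, numIterations);
    double wssse = model.computeCost(parsedData.rdd());
    System.out.println("k = " + k + "  WSSSE = " + wssse);
}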