// Not tidied up yet; posting the code as-is for now.
package spark.examples;
import spark.api.java.*;
import spark.api.java.function.*;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ArrayList;
import java.util.Random;
import java.util.Arrays;
import java.io.Serializable;
import spark.util.Vector;
import scala.Tuple2;
public class JavaKMeans {
private static HashMap> kPoint = new HashMap>();
static class Closest implements Serializable {
int i;
Iterable iter;
public Closest(Iterable dd, int ii) {
this.i = ii;
this.iter = dd;
}
public Closest add(Closest clt) {
int ii = this.i + clt.i;
Iterator titer = this.iter.iterator();
Iterator oiter = clt.iter.iterator();
Iterator tmp = this.iter.iterator();
int count = 0;
while (tmp.hasNext()) {
tmp.next();
count ++;
}
Double[] dd = new Double[count];
count = 0;
while (titer.hasNext() && oiter.hasNext()) {
Double me = titer.next();
Double other = oiter.next();
dd[count ++] = me + other;
}
Iterable newiter = Arrays.asList(dd);
return new Closest(newiter, ii);
}
public Iterable division() {
Iterator tmp = iter.iterator();
Iterator iterator = iter.iterator();
int count = 0;
while(tmp.hasNext()) {
tmp.next();
count ++;
}
Double[] dd = new Double[count];
count = 0;
while (iterator.hasNext()) {
dd[count ++] = iterator.next() / this.i;
}
return Arrays.asList(dd);
}
}
public static void main(String[] args) {
if (args.length < 4) {
System.err
.println("Usage: SparkLocalKMeans ");
System.exit(1);
}
int k = Integer.parseInt(args[2]);
double converge = Double.parseDouble(args[3]);
double tmpDist = 1.0;
JavaSparkContext sc = new JavaSparkContext(args[0], "SparkKmeans");
JavaRDD lines = sc.textFile(args[1]);
class Split extends FlatMapFunction {
public Iterable call(String t) throws Exception {
Iterable str = Arrays.asList(t.split(" "));
Iterator iter = str.iterator();
int count = 0;
while(iter.hasNext()) {
iter.next();
count ++;
}
Double[] dd = new Double[count];
Iterator iterator = str.iterator();
count = 0;
while (iterator.hasNext()) {
dd[count ++] = Double.parseDouble(iterator.next());
}
return Arrays.asList(dd);
}
}
JavaRDD> digts = lines.map(new Split());
/*List> dd = digts.collect();
for (Iterable iarr : dd) {
Iterator iter = iarr.iterator();
while(iter.hasNext()) {
Double D = iter.next();
System.out.print(D + "\t");
}
System.out.println();
}*/
digts.cache();
List> points = digts.takeSample(false, k, 42);
int e = 1;
for (Iterable iter : points) {
System.out.print(e + "\t");
kPoint.put(e++, iter);
Iterator ator = iter.iterator();
while (ator.hasNext()) {
System.out.print(ator.next() + ", ");
}
System.out.println();
}
// above code is OK !!!!!!!!!!
class tupleCls extends PairFunction, Integer, Closest> {
public Tuple2 call(Iterable t) throws Exception {
Iterator tt = t.iterator();
while (tt.hasNext()) {
System.out.print(tt.next() + ", ");
}
System.out.println();
int bestIndex = 0;
double closest = Double.MAX_VALUE;
for (int i = 1; i <= kPoint.size(); ++i) {
double tmpDist = 0.0;
Iterator iterkp = kPoint.get(i).iterator();
Iterator iterp = t.iterator();
while(iterkp.hasNext() && iterp.hasNext()) {
double x1 = iterkp.next();
double x2 = iterp.next();
tmpDist += (x1 - x2) * (x1 - x2);
}
if (tmpDist < closest) {
System.out.println("dist is :" + tmpDist + " and index is " + i);
closest = tmpDist;
bestIndex = i;
}
}
return new Tuple2(bestIndex, new Closest(t, 1));
}
}
class reduceCls extends Function2 {
public Closest call(Closest x1, Closest x2) throws Exception {
return x1.add(x2);
//return null;
}
}
class newPoint extends PairFunction, Integer, Iterable> {
public Tuple2> call(Tuple2 clt) throws Exception {
Iterable iter = clt._2.division();
return new Tuple2(clt._1, iter);
}
}
int count = 0;
while (tmpDist > converge) {
JavaPairRDD closest = digts.map(new tupleCls());
System.out.println("=======================================> After first map ...");
List> dd = closest.collect();
for (Tuple2 tuple : dd) {
System.out.print(tuple._1 + "\t");
Closest clt = (Closest) tuple._2;
Iterator iter = clt.iter.iterator();
while (iter.hasNext()) {
System.out.print(iter.next() + ", ");
}
System.out.println(clt.i);
}
JavaPairRDD pointStats = closest.reduceByKey(new reduceCls());
System.out.println("=======================================> After reduceByKey ...");
List> ddd = pointStats.collect();
for (Tuple2 tuple : ddd) {
System.out.print(tuple._1 + "\t");
Closest clt = (Closest) tuple._2;
Iterator iter = clt.iter.iterator();
while (iter.hasNext()) {
System.out.print(iter.next() + ", ");
}
System.out.println(clt.i);
}
JavaPairRDD> newPoints = pointStats.map(new newPoint());
List>> lnewPoints = newPoints.collect();
System.out.println("=======================================> After second map ...");
for (Tuple2 tuple : lnewPoints) {
System.out.println(tuple._1 + "\t");
Iterable iter = (Iterable)tuple._2;
Iterator ator = iter.iterator();
while (ator.hasNext()) {
System.out.print(ator.next() + ",");
}
System.out.println();
}
tmpDist = 0.0;
for (Tuple2 tuple : lnewPoints) {
Iterator oldp = kPoint.get(tuple._1).iterator();
Iterator newp = ((Iterable)tuple._2).iterator();
while (oldp.hasNext() && newp.hasNext()) {
double x1 = oldp.next();
double x2 = newp.next();
tmpDist += (x1 - x2) * (x1 - x2);
}
System.out.println("tmpDist is " + tmpDist);
kPoint.put((Integer)tuple._1, (Iterable)tuple._2);
}
count ++;
//if (count > 2) break;
}
for (int i = 1; i <= kPoint.size(); ++i) {
System.out.print(i + ": ");
Iterator iter = kPoint.get(i).iterator();
while (iter.hasNext()) {
System.out.print(iter.next() + ",");
}
System.out.println();
}
System.exit(0);
}
}