用 KMeans 分割(split)连续值(2)

把用 KMeans 分割连续值的思路重新想了一下:center 的边缘才是分割的界限,上一篇博文的做法完全不对。这就需要在做 KMeans 时计算每个 center 的最大距离和最小距离(距离有方向)。在现在的 Spark MLlib 中,如果想计算最大和最小距离,需要在得到最终结果后再遍历 RDD 的每个数据,算出它属于哪个 center,然后算出每个 center 的最大和最小距离。

我用 Java 重新写了一下 KMeans,输入是 RDD<Double>,并在 model 中加入了 split 方法。一共三个类:SimpleKMeans、SimpleKMeansModel 和 SimpleLocalKMeans,代码如下,逻辑和 Spark 的 KMeans 是完全一样的。


/**
 * SimpleLocalKMeans Class
 */


public class SimpleLocalKMeans
{


public static SimpleVector[] kMeansPlusPlus( int seed, SimpleVector[] points,
Double[] weights, int k, int maxIterations )
{


Random random = new Random( seed );
int dimensions = 1;


SimpleVector[] centers = new SimpleVector[k];
centers[0] = pickWeighted( random, points, weights );


for ( int i = 1; i < k; i++ )
{
SimpleVector[] curCenters = takeValue( centers, i  );
double sum = sumValue( curCenters, points, weights );


double r = random.nextDouble( ) * sum;


double cumulativeScore = 0.0;
int j = 0;
while ( j < points.length && cumulativeScore < r )
{
cumulativeScore += weights[j]
* SimpleKMeans.findClosest( curCenters, points[j] )._2;
j++;
}


if ( j == 0 )
{
centers[i] = points[0].toDense( );
}
else
{
centers[i] = points[j - 1].toDense( );
}
}


Integer[] oldClosest = (Integer[]) Array.fill( points.length, new MyFunction0<Integer>( )
{


@Override
public Integer apply( )
{
return -1;
}
}, ClassManifestFactory.classType( Integer.class ) );


int iteration = 0;
boolean moved = true;


while ( moved && iteration < maxIterations )
{
moved = false;
Double[] counts = (Double[]) Array.fill( k, new MyFunction0<Double>( )
{


@Override
public Double apply( )
{


return 0.0;
}
}, ClassManifestFactory.classType( Double.class ) );


Double[] sums = (Double[]) Array.fill( k, new MyFunction0<Double>( )
{


@Override
public Double apply( )
{
//return Vectors.zeros( dimensions );
return 0.0;
}
}, ClassManifestFactory.classType( Double.class ) );


int i = 0;
while ( i < points.length )
{
SimpleVector p = points[i];


int index = SimpleKMeans.findClosest( centers, p )._1;
//BLAS.axpy( weights[i], Vectors.dense( new double[]{p.vector( )}), sums[index] );
sums[index] = sums[index] + weights[i]*p.vector( );
counts[index] += weights[i];


if ( index != oldClosest[i] )
{
moved = true;
oldClosest[i] = index;
}
i += 1;


}


int j = 0;
while ( j < k )
{
if ( counts[j] == 0.0 )
{


centers[j] = points[random.nextInt( points.length )].toDense( );
}
else
{
//BLAS.scal( 1.0 / counts[j], sums[j] );
sums[j] = sums[j]/counts[j];
centers[j] = new SimpleVector( sums[j] );
}
j += 1;
}
iteration += 1;
}

if ( iteration == maxIterations )
{
//Log
}
else
{
//Log
}


return centers;
}


private static double sumValue( SimpleVector[] curCenters,
SimpleVector[] points, Double[] weights )
{
double retValue = 0.0;
for ( int i = 0; i < points.length; i++ )
{
retValue = retValue
+ weights[i]
* SimpleKMeans.findClosest( curCenters, points[i] )._2;
}
return retValue;
}


private static <T> T pickWeighted( Random random, T[] data,
Double[] weights )
{
double randomValue = random.nextDouble( ) * sumArray( weights );
int i = 0;
double curWeight = 0.0;
while ( i < weights.length && curWeight < randomValue )
{
curWeight = curWeight + weights[i];
i++;
}
return data[i - 1];
}


private static double sumArray( Double[] array )
{
double retValue = 0.0;
for ( int i = 0; i < array.length; i++ )
{
retValue = retValue + array[i];
}


return retValue;
}


private static SimpleVector[] takeValue( SimpleVector[] array, int n )
{
SimpleVector[] retValue = new SimpleVector[n];
for ( int i = 0; i < n; i++ )
{
retValue[i] = array[i];
}


return retValue;
}


private static abstract class MyFunction0<R> extends AbstractFunction0<R> implements Serializable
{


}
}


import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;


import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;


import amlib.simplekm.SimpleKMeans.SimpleVector;
import scala.reflect.ClassManifestFactory;


/**
 * 
 */


public class SimpleKMeansModel implements Serializable
{
private SimpleKMeans.SimpleVector[] clusterCentersWithNorm;

public SimpleKMeansModel(SimpleVector[] vectors)
{
this.clusterCentersWithNorm = vectors;
Arrays.sort( clusterCentersWithNorm, new Comparator<SimpleVector>( )
{


@Override
public int compare( SimpleVector t1, SimpleVector t2 )
{
return t1.vector( ) - t2.vector( ) >= 0? 1:-1;
}
});
}


public double[] clusterCenters( )
{
double[] retValue = new double[clusterCentersWithNorm.length];
for (int i=0; i<clusterCentersWithNorm.length; i++)
{
retValue[i] = clusterCentersWithNorm[i].vector( );
}
return retValue;
}



public void setClusterCentersWithNorm(
SimpleKMeans.SimpleVector[] clusterCentersWithNorm )
{
this.clusterCentersWithNorm = clusterCentersWithNorm;
}


public int predict(double point)
{
return SimpleKMeans.findClosest( clusterCentersWithNorm, new SimpleKMeans.SimpleVector(point ) )._1;
}

public RDD<Integer> predict(RDD<Double> points)
{
Broadcast<SimpleKMeans.SimpleVector[]> bcClusterCentersWithNorm = points.context( ).broadcast( clusterCentersWithNorm, 
ClassManifestFactory.classType( SimpleKMeans.SimpleVector[].class ) );

return points.toJavaRDD( ).map( s->SimpleKMeans.findClosest( bcClusterCentersWithNorm.getValue( ), new SimpleKMeans.SimpleVector(s) )._1 ).rdd( );
}

public Double[] split()
{
if (clusterCentersWithNorm.length < 2)
{
return new Double[]{};
}
Double[] retValue = new Double[clusterCentersWithNorm.length - 1];
for (int i=0; i<clusterCentersWithNorm.length - 1; i++)
{
retValue[i] = (clusterCentersWithNorm[i].vector( ) + clusterCentersWithNorm[i].max( ) + 
clusterCentersWithNorm[i + 1].vector( ) + clusterCentersWithNorm[i+1].min( ))/2.0;
}

return retValue;

}
private static class Cluster
{


private int id;
private Vector point;


public Cluster( int id, Vector point )
{
super( );
this.id = id;
this.point = point;
}
}


}


import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Consumer;


import org.apache.spark.Accumulator;
import org.apache.spark.AccumulatorParam;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.random.XORShiftRandom;


import scala.Array;
import scala.Option;
import scala.Some;
import scala.Tuple2;
import scala.Tuple4;
import scala.reflect.ClassManifestFactory;
import scala.runtime.AbstractFunction0;
import scala.runtime.AbstractFunction1;
import scala.runtime.AbstractFunction2;


/**
 * 
 */


public class SimpleKMeans implements Serializable
{


/** Initialization mode: use randomly sampled points as the starting centers. */
public static final String RANDOM = "random";
/** Initialization mode: k-means|| (scalable k-means++); any non-RANDOM mode selects this path. */
public static final String K_MEANS_PARALLEL = "k-means||";


/** Number of clusters to find. */
private int k;
/** Upper bound on Lloyd's iterations performed by the training loop. */
private int maxIterations;
/** Number of independent runs; the run with the lowest accumulated cost is kept. */
private int runs;
/** Either {@link #RANDOM} or {@link #K_MEANS_PARALLEL}. */
private String initializationMode;
/** Number of sampling rounds performed by k-means|| initialization. */
private int initializationSteps;
/** Convergence tolerance on center movement, compared against epsilon * epsilon. */
private double epsilon;
/** Seed for the random initialization. */
private long seed;


/**
 * Creates a fully configured instance. No argument validation is performed here.
 *
 * @param k                   number of clusters
 * @param maxIterations       maximum number of Lloyd's iterations
 * @param runs                number of independent runs
 * @param initializationMode  {@link #RANDOM} or {@link #K_MEANS_PARALLEL}
 * @param initializationSteps number of k-means|| sampling rounds
 * @param epsilon             convergence tolerance
 * @param seed                random seed for initialization
 */
public SimpleKMeans( int k, int maxIterations, int runs, String initializationMode,
        int initializationSteps, double epsilon, long seed )
{
    this.seed = seed;
    this.epsilon = epsilon;
    this.initializationSteps = initializationSteps;
    this.initializationMode = initializationMode;
    this.runs = runs;
    this.maxIterations = maxIterations;
    this.k = k;
}

/**
 * Creates an instance with default settings: k = 2, 20 iterations, 1 run, k-means||
 * initialization with 5 steps, epsilon = 1e-4 and a freshly drawn random seed.
 */
public SimpleKMeans( )
{
    this( 2, 20, 1, K_MEANS_PARALLEL, 5, 1e-4, new Random( ).nextLong( ) );
}


/** Returns the number of clusters. */
public int getK( )
{
    return k;
}

/**
 * Sets the number of clusters.
 *
 * @param k number of clusters; must be positive
 * @return this instance, for chaining
 * @throws IllegalArgumentException if {@code k <= 0}
 */
public SimpleKMeans setK( int k )
{
    if ( k <= 0 )
    {
        throw new IllegalArgumentException( "Number of clusters must be positive but got " + k );
    }
    this.k = k;
    return this;
}

/** Returns the maximum number of Lloyd's iterations. */
public int getMaxIterations( )
{
    return maxIterations;
}

/**
 * Sets the maximum number of Lloyd's iterations.
 *
 * @param maxIterations must be non-negative
 * @return this instance, for chaining
 * @throws IllegalArgumentException if {@code maxIterations < 0}
 */
public SimpleKMeans setMaxIterations( int maxIterations )
{
    if ( maxIterations < 0 )
    {
        throw new IllegalArgumentException( "Maximum of iterations must be nonnegative but got " + maxIterations );
    }
    this.maxIterations = maxIterations;
    return this;
}

/** Returns the number of independent runs. */
public int getRuns( )
{
    return runs;
}

/**
 * Sets the number of independent runs; the cheapest run is kept.
 *
 * @param runs must be positive
 * @return this instance, for chaining
 * @throws IllegalArgumentException if {@code runs <= 0}
 */
public SimpleKMeans setRuns( int runs )
{
    if ( runs <= 0 )
    {
        throw new IllegalArgumentException( "Number of runs must be positive but got " + runs );
    }
    this.runs = runs;
    return this;
}

/** Returns the initialization mode. */
public String getInitializationMode( )
{
    return initializationMode;
}

/**
 * Sets the initialization mode.
 *
 * @param initializationMode {@link #RANDOM} or {@link #K_MEANS_PARALLEL}
 * @return this instance, for chaining
 * @throws IllegalArgumentException for any other value (including {@code null}); previously
 *         unknown modes silently fell back to k-means||
 */
public SimpleKMeans setInitializationMode( String initializationMode )
{
    if ( !RANDOM.equals( initializationMode ) && !K_MEANS_PARALLEL.equals( initializationMode ) )
    {
        throw new IllegalArgumentException( "Invalid initialization mode: " + initializationMode );
    }
    this.initializationMode = initializationMode;
    return this;
}

/** Returns the number of k-means|| sampling rounds. */
public int getInitializationSteps( )
{
    return initializationSteps;
}

/**
 * Sets the number of k-means|| sampling rounds.
 *
 * @param initializationSteps must be positive
 * @return this instance, for chaining
 * @throws IllegalArgumentException if {@code initializationSteps <= 0}
 */
public SimpleKMeans setInitializationSteps( int initializationSteps )
{
    if ( initializationSteps <= 0 )
    {
        throw new IllegalArgumentException( "Number of initialization steps must be positive but got " + initializationSteps );
    }
    this.initializationSteps = initializationSteps;
    return this;
}

/** Returns the convergence tolerance. */
public double getEpsilon( )
{
    return epsilon;
}

/**
 * Sets the convergence tolerance on center movement.
 *
 * @param epsilon must be non-negative (NaN is rejected)
 * @return this instance, for chaining
 * @throws IllegalArgumentException if {@code epsilon} is negative or NaN
 */
public SimpleKMeans setEpsilon( double epsilon )
{
    if ( !( epsilon >= 0 ) )
    {
        throw new IllegalArgumentException( "Distance threshold must be nonnegative but got " + epsilon );
    }
    this.epsilon = epsilon;
    return this;
}

/** Returns the random seed used for initialization. */
public long getSeed( )
{
    return seed;
}

/**
 * Sets the random seed used for initialization.
 *
 * @return this instance, for chaining
 */
public SimpleKMeans setSeed( long seed )
{
    this.seed = seed;
    return this;
}

/** Returns {@link #RANDOM}. */
public static String getRandom( )
{
    return RANDOM;
}

/** Returns {@link #K_MEANS_PARALLEL}. */
public static String getkMeansParallel( )
{
    return K_MEANS_PARALLEL;
}


/**
 * Finds the center nearest to {@code point} using absolute (1-D) distance.
 * Ties resolve to the lowest index.
 *
 * @param centers candidate centers
 * @param point   the point to classify
 * @return (index of the nearest center, its distance); (0, +Infinity) when {@code centers}
 *         is empty
 */
public static Tuple2<Integer, Double> findClosest( SimpleVector[] centers,
        SimpleVector point )
{
    int nearest = 0;
    double nearestDistance = Double.POSITIVE_INFINITY;
    int idx = 0;
    for ( SimpleVector candidate : centers )
    {
        double gap = Math.abs( candidate.vector( ) - point.vector( ) );
        if ( gap < nearestDistance )
        {
            nearest = idx;
            nearestDistance = gap;
        }
        idx++;
    }
    return new Tuple2<Integer, Double>( nearest, nearestDistance );
}


/**
 * Trains a model on the given values.
 *
 * @param data values to cluster
 * @return the fitted model
 */
public SimpleKMeansModel run( RDD<Double> data )
{
    // Wrap each raw value so the training code works on SimpleVector points.
    RDD<SimpleVector> points = data.toJavaRDD( )
            .map( value -> new SimpleVector( value ) )
            .rdd( );
    return runAlgorithm( points );
}


private SimpleKMeansModel runAlgorithm( RDD<SimpleVector> data )
{
SparkContext sc = data.context( );
sc.env( ).blockManager( ).master( ).removeRdd( 1, true );
int numRuns = runs;
SimpleVector[][] centers = null;
if ( RANDOM.equals( initializationMode ) )
{
centers = initKMeansRandom( data );
}
else
{
centers = initKMeansParallel( data );
}
Boolean[] active = (Boolean[]) Array.fill( numRuns, new MyFunction0<Boolean>( )
{


@Override
public Boolean apply( )
{
return true;
}
}, ClassManifestFactory.classType( Boolean.class ) );


Double[] costs = (Double[]) Array.fill( numRuns, new MyFunction0<Double>( )
{


@Override
public Double apply( )
{
return 0.0;
}
}, ClassManifestFactory.classType( Double.class ) );


List<Integer> activeRuns = new ArrayList<Integer>( );
initArraybuffer( activeRuns, numRuns );


int iteration = 0;


while ( iteration < maxIterations && !activeRuns.isEmpty( ) )
{


// activeRuns.map( null, null ).
Accumulator<Double>[] costAccums = new Accumulator[activeRuns.size( )];
for ( int i = 0; i < activeRuns.size( ); i++ )
{
costAccums[i] = sc.accumulator( 0.0, doubleAccumulatorParam );
}
SimpleVector[][] activeCenters = new SimpleVector[activeRuns.size( )][];
for ( int i = 0; i < activeRuns.size( ); i++ )
{
activeCenters[i] = centers[activeRuns.get( i )];
}


Broadcast<SimpleVector[][]> bcActiveCenters = sc.broadcast( activeCenters, ClassManifestFactory.classType( SimpleVector[][].class ) );


Map<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>> map = data.toJavaRDD( ).mapPartitions( new FlatMapFunction<Iterator<SimpleVector>, Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>>( )
{


@Override
public Iterable<Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>>
call( Iterator<SimpleVector> t ) throws Exception
{
SimpleVector[][] thisActiveCenters = bcActiveCenters.getValue( );
int runs = thisActiveCenters.length;
int k = thisActiveCenters[0].length;


//int dims = thisActiveCenters[0][0].vector.size( );
int dims = 1;
// Array.fill( arg0, arg1, arg2, arg3 )
//Vector[][] sums = fillArrayVector( numRuns, k, dims );
Double[][] sums = fillArrayDouble( numRuns, k, 0.0 );
Long[][] counts = fillArrayLong( numRuns, k );
Double[][] mins = fillArrayDouble( numRuns, k, Double.MAX_VALUE );
Double[][] maxs = fillArrayDouble( numRuns, k, Double.MIN_VALUE );
while ( t.hasNext( ) )
{
SimpleVector point = t.next( );
for ( int i = 0; i < runs; i++ )
{
Tuple2<Integer, Double> value = findClosest( thisActiveCenters[i], point );
int bestCenter = value._1;
double cost = value._2;
costAccums[i].add( cost );


//double sum = sums[i][bestCenter];
//BLAS.axpy( 1.0, Vectors.dense( new double[]{point.vector( )}), sum );
sums[i][bestCenter] = sums[i][bestCenter] + point.vector( );
counts[i][bestCenter] = counts[i][bestCenter] + 1;

double distance =  point.vector( ) - thisActiveCenters[i][bestCenter].vector( );
if (distance < mins[i][bestCenter])
{
mins[i][bestCenter] = distance;
}
if (distance > maxs[i][bestCenter])
{
maxs[i][bestCenter] = distance;
}
}
}


List<Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>> list = new ArrayList<Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>>( );
for ( int i = 0; i < runs; i++ )
{
for ( int j = 0; j < k; j++ )
{
Tuple2<Integer, Integer> key = new Tuple2<Integer, Integer>( i, j );
Tuple4<Double, Long, Double, Double> value = new Tuple4<Double, Long, Double, Double>( sums[i][j], counts[i][j], mins[i][j], maxs[i][j] );
list.add( new Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>( key, value ) );
}
}


return list;
}
} ).mapToPair( new PairFunction<Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>, Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>( )
{


@Override
public Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>>
call( Tuple2<Tuple2<Integer, Integer>, Tuple4<Double, Long, Double, Double>> t )
throws Exception
{
return t;
}
} ).reduceByKey( new Function2<Tuple4<Double, Long, Double, Double>, Tuple4<Double, Long, Double, Double>, Tuple4<Double, Long, Double, Double>>( )
{


@Override
public Tuple4<Double, Long, Double, Double> call( Tuple4<Double, Long, Double, Double> v1,
Tuple4<Double, Long, Double, Double> v2 ) throws Exception
{
return mergeContribs( v1, v2 );
}


} ).collectAsMap( );


bcActiveCenters.unpersist( false );


for ( int i = 0; i < activeRuns.size( ); i++ )
{
int run = activeRuns.get( i );
boolean changed = false;
int j = 0;
while ( j < k )
{
// get i j
Tuple4<Double, Long, Double, Double> entry = map.get( new Tuple2<Integer, Integer>( i, j ) );


Double sum = entry._1( );
Long count = entry._2( );
if ( count != 0 )
{
//BLAS.scal( 1.0 / count, sum );
sum = sum/count;
SimpleVector newCenter = new SimpleVector( sum );
newCenter.min = entry._3( );
newCenter.max = entry._4( );
if ( fastSquaredDistance( newCenter, centers[run][j] ) > epsilon
* epsilon )
{
changed = true;
}


centers[run][j] = newCenter;
}
j++;
// map.entrySet( );
}


if ( !changed )
{
active[run] = false;
}
costs[run] = costAccums[i].value( );
}


// a.stream( ).collect( a );
// sc.parallelize( activeRuns.toSeq( ), 5,
// ClassManifestFactory.classType( Integer.class ) ).toJavaRDD(
// ).filter( )
// Seq<Integer> activeSeq = activeRuns.filter( new
// MyFunction1<Integer, Object>( )
// {
//
// @Override
// public Object apply( Integer t )
// {
// return active[t];
// }
//
// }).toSeq( );
// activeRuns.clear( );


// activeRuns.append( activeSeq );


List<Integer> newActive = new ArrayList<Integer>( );


for ( int i : activeRuns )
{
if ( active[i] )
{
newActive.add( i );
}
}
// activeRuns.clear( );
// activeRuns.addAll( );
// activeRuns.add
activeRuns = newActive;
iteration++;
}


int bestRun = finMinValue( costs );


// double[] vectors = new double[centers[bestRun].length];
// for ( int i = 0; i < centers[bestRun].length; i++ )
// {
// vectors[i] = centers[bestRun][i].vector( );
// }

SimpleVector[] vectors = new SimpleVector[centers[bestRun].length];
for ( int i = 0; i < centers[bestRun].length; i++ )
{
vectors[i] = centers[bestRun][i];
}

return new SimpleKMeansModel( vectors );
// activeRuns.append( JavaConverters.asScalaIterableConverter( list ).
// );
}


/**
 * Returns the index of the smallest cost (first one on ties). Returns 0 when the array is
 * empty or no cost is below Double.MAX_VALUE.
 */
private static int finMinValue( Double[] costs )
{
    int best = 0;
    double lowest = Double.MAX_VALUE;
    int position = 0;
    for ( Double cost : costs )
    {
        if ( cost < lowest )
        {
            lowest = cost;
            best = position;
        }
        position++;
    }
    return best;
}


/** Returns a {@code runs x k} grid of boxed longs, every cell set to 0L. */
private static Long[][] fillArrayLong( int runs, int k )
{
    Long[][] grid = new Long[runs][k];
    for ( Long[] row : grid )
    {
        for ( int col = 0; col < row.length; col++ )
        {
            row[col] = 0L;
        }
    }
    return grid;
}

/** Returns a {@code runs x k} grid of boxed doubles, every cell set to {@code value}. */
private static Double[][] fillArrayDouble( int runs, int k, double value )
{
    Double[][] grid = new Double[runs][k];
    for ( Double[] row : grid )
    {
        for ( int col = 0; col < row.length; col++ )
        {
            row[col] = value;
        }
    }
    return grid;
}


/**
 * Returns a {@code runs x k} grid where every cell holds its own zero vector of size
 * {@code dims} (cells never share an instance).
 */
private static Vector[][] fillArrayVector( int runs, int k, int dims )
{
    Vector[][] grid = new Vector[runs][k];
    for ( Vector[] row : grid )
    {
        for ( int col = 0; col < row.length; col++ )
        {
            row[col] = Vectors.zeros( dims );
        }
    }
    return grid;
}


private static Tuple4<Double, Long, Double, Double> mergeContribs( Tuple4<Double, Long, Double, Double> x,
Tuple4<Double, Long, Double, Double> y )
{
//BLAS.axpy( 1.0, x._1( ), y._1( ) );
double sum = x._1( ) + y._1( );
double min = Math.min( x._3( ), y._3( ) );
double max = Math.max( x._4( ),  y._4( ) );
return new Tuple4<Double, Long, Double, Double>( sum, x._2( ) + y._2( ), min, max );
}


/** Appends the run indices 0 .. run-1 to {@code activeRuns}. */
private void initArraybuffer( List<Integer> activeRuns, int run )
{
    int next = 0;
    while ( next < run )
    {
        activeRuns.add( next );
        next++;
    }
}


/**
 * k-means|| initialization for one-dimensional points, following Spark MLlib's
 * KMeans.initKMeansParallel.
 * <p>
 * Each run starts from one randomly sampled point. In each of the configured sampling
 * rounds, every point is picked as a new candidate with probability
 * {@code 2 * k * cost / totalCost}, where cost is the distance (per findClosest) to the
 * nearest candidate found so far. Finally, each candidate is weighted by the number of
 * points closest to it, and the candidates are reduced to {@code k} centers per run with
 * {@link SimpleLocalKMeans#kMeansPlusPlus}.
 *
 * @param data the points to cluster
 * @return for each run, its {@code k} initial centers
 */
public SimpleVector[][] initKMeansParallel( RDD<SimpleVector> data )
{
    final int numRuns = runs;
    final int numClusters = k;

    // Candidates accepted so far (centers) and candidates picked in the current round (newCenters).
    List<SimpleVector>[] centers = newCenterLists( numRuns );
    List<SimpleVector>[] newCenters = newCenterLists( numRuns );

    // Before any candidate exists, every point is infinitely far away in every run.
    RDD<Double[]> costs = data.toJavaRDD( )
            .map( point -> constantCosts( numRuns, Double.POSITIVE_INFINITY ) )
            .rdd( );

    final int newSeed = new XORShiftRandom( seed ).nextInt( );
    List<SimpleVector> sample = data.toJavaRDD( ).takeSample( true, numRuns, newSeed );
    for ( int r = 0; r < numRuns; r++ )
    {
        newCenters[r].add( sample.get( r ) );
    }

    for ( int step = 0; step < initializationSteps; step++ )
    {
        final int currentStep = step;
        // Broadcast arrays rather than lists so executors don't convert once per point.
        final Broadcast<SimpleVector[][]> bcNewCenters = data.sparkContext( ).broadcast(
                toArrays( newCenters ), ClassManifestFactory.classType( SimpleVector[][].class ) );

        // Lower each point's cost by its distance to this round's new candidates.
        RDD<Double[]> preCosts = costs;
        costs = data.zip( preCosts, ClassManifestFactory.classType( Double[].class ) ).toJavaRDD( )
                .map( pointAndCost -> {
                    SimpleVector[][] roundCenters = bcNewCenters.getValue( );
                    Double[] updated = new Double[numRuns];
                    for ( int r = 0; r < numRuns; r++ )
                    {
                        updated[r] = Math.min( findClosest( roundCenters[r], pointAndCost._1( ) )._2( ),
                                pointAndCost._2( )[r] );
                    }
                    return updated;
                } )
                .rdd( )
                .persist( StorageLevels.MEMORY_AND_DISK );

        // Per-run total cost. The accumulator is the (per-partition cloned) zero array,
        // never an element of the cached costs RDD, so mutating it in place is safe.
        Function2<Double[], Double[], Double[]> addCosts = ( acc, more ) -> {
            for ( int r = 0; r < numRuns; r++ )
            {
                acc[r] = acc[r] + more[r];
            }
            return acc;
        };
        final Double[] sumCosts = costs.toJavaRDD( ).aggregate( constantCosts( numRuns, 0.0 ), addCosts, addCosts );

        bcNewCenters.unpersist( false );
        preCosts.unpersist( false );

        // Sample new candidates. Only the points actually selected are returned to the
        // driver; non-selected points produce nothing.
        List<Tuple2<SimpleVector, Integer[]>> chosen = data.zip( costs, ClassManifestFactory.classType( Double[].class ) ).toJavaRDD( )
                .mapPartitionsWithIndex( ( index, pointsWithCosts ) -> {
                    XORShiftRandom random = new XORShiftRandom( newSeed ^ ( currentStep << 16 ) ^ index );
                    List<Tuple2<SimpleVector, Integer[]>> picked = new ArrayList<Tuple2<SimpleVector, Integer[]>>( );
                    while ( pointsWithCosts.hasNext( ) )
                    {
                        Tuple2<SimpleVector, Double[]> pc = pointsWithCosts.next( );
                        Double[] c = pc._2( );
                        List<Integer> rs = new ArrayList<Integer>( );
                        for ( int r = 0; r < numRuns; r++ )
                        {
                            if ( random.nextDouble( ) < 2.0 * c[r] * numClusters / sumCosts[r] )
                            {
                                rs.add( r );
                            }
                        }
                        if ( !rs.isEmpty( ) )
                        {
                            picked.add( new Tuple2<SimpleVector, Integer[]>( pc._1( ), rs.toArray( new Integer[rs.size( )] ) ) );
                        }
                    }
                    return picked.iterator( );
                }, false )
                .collect( );

        mergeNewCenters( centers, newCenters, numRuns );
        for ( Tuple2<SimpleVector, Integer[]> pick : chosen )
        {
            for ( Integer run : pick._2( ) )
            {
                newCenters[run].add( pick._1( ).toDense( ) );
            }
        }
    }

    mergeNewCenters( centers, newCenters, numRuns );
    costs.unpersist( false );

    // Weight every candidate by the number of points for which it is the nearest candidate.
    final SimpleVector[][] candidates = toArrays( centers );
    final Broadcast<SimpleVector[][]> bcCenters = data.sparkContext( ).broadcast( candidates,
            ClassManifestFactory.classType( SimpleVector[][].class ) );
    Map<Tuple2<Integer, Integer>, Double> weightMap = data.toJavaRDD( ).flatMapToPair(
            new PairFlatMapFunction<SimpleKMeans.SimpleVector, Tuple2<Integer, Integer>, Double>( )
            {
                @Override
                public Iterable<Tuple2<Tuple2<Integer, Integer>, Double>> call( SimpleVector point ) throws Exception
                {
                    SimpleVector[][] runCenters = bcCenters.value( );
                    List<Tuple2<Tuple2<Integer, Integer>, Double>> hits = new ArrayList<Tuple2<Tuple2<Integer, Integer>, Double>>( numRuns );
                    for ( int r = 0; r < numRuns; r++ )
                    {
                        int best = findClosest( runCenters[r], point )._1( );
                        hits.add( new Tuple2<Tuple2<Integer, Integer>, Double>( new Tuple2<Integer, Integer>( r, best ), 1.0 ) );
                    }
                    return hits;
                }
            } ).reduceByKey( ( a, b ) -> a + b ).collectAsMap( );
    bcCenters.unpersist( false );

    // Reduce each run's weighted candidates to k centers locally.
    SimpleVector[][] result = new SimpleVector[numRuns][];
    for ( int r = 0; r < numRuns; r++ )
    {
        SimpleVector[] myCenters = candidates[r];
        Double[] myWeights = new Double[myCenters.length];
        for ( int j = 0; j < myCenters.length; j++ )
        {
            Double w = weightMap.get( new Tuple2<Integer, Integer>( r, j ) );
            myWeights[j] = ( w == null ) ? 0.0 : w;
        }
        result[r] = SimpleLocalKMeans.kMeansPlusPlus( r, myCenters, myWeights, numClusters, 30 );
    }
    return result;
}

/** Creates {@code n} independent, empty, mutable center lists. */
@SuppressWarnings( "unchecked" )
private static List<SimpleVector>[] newCenterLists( int n )
{
    List<SimpleVector>[] lists = new List[n];
    for ( int i = 0; i < n; i++ )
    {
        lists[i] = new ArrayList<SimpleVector>( );
    }
    return lists;
}

/** Snapshots every list into its own array (the lists themselves are not modified). */
private static SimpleVector[][] toArrays( List<SimpleVector>[] lists )
{
    SimpleVector[][] arrays = new SimpleVector[lists.length][];
    for ( int i = 0; i < lists.length; i++ )
    {
        arrays[i] = lists[i].toArray( new SimpleVector[lists[i].size( )] );
    }
    return arrays;
}

/** Returns a new array of length {@code n} with every entry set to {@code value}. */
private static Double[] constantCosts( int n, double value )
{
    Double[] filled = new Double[n];
    for ( int i = 0; i < n; i++ )
    {
        filled[i] = value;
    }
    return filled;
}


/**
 * For each of the first {@code runs} runs, moves the freshly picked candidates into the
 * accepted set and empties the fresh list.
 */
private static void mergeNewCenters( List<SimpleVector>[] centers,
        List<SimpleVector>[] newCenters, int runs )
{
    int r = 0;
    while ( r < runs )
    {
        List<SimpleVector> fresh = newCenters[r];
        centers[r].addAll( fresh );
        fresh.clear( );
        r++;
    }
}


/**
 * Random initialization: samples {@code runs * k} points with replacement and gives each run
 * its own consecutive slice of {@code k} of them.
 */
private SimpleVector[][] initKMeansRandom( RDD<SimpleVector> data )
{
    int sampleSeed = new XORShiftRandom( seed ).nextInt( );
    final List<SimpleVector> sample = data.toJavaRDD( ).takeSample( true, runs * k, sampleSeed );

    SimpleVector[][] perRun = new SimpleVector[runs][];
    for ( int r = 0; r < runs; r++ )
    {
        int from = r * k;
        perRun[r] = sample.subList( from, from + k ).toArray( new SimpleVector[k] );
    }
    return perRun;
}


/**
 * A one-dimensional point or cluster center.
 * <p>
 * Besides its coordinate, a center carries {@code min}/{@code max}: the training loop stores
 * signed offsets of the cluster's extreme points relative to a center there. Both are 0.0
 * until set.
 */
public static class SimpleVector implements Serializable
{
    /** The coordinate. */
    private double vector;
    /** Upper spread of the cluster, as an offset from the center. */
    private double max;
    /** Lower spread of the cluster, as an offset from the center. */
    private double min;

    /** Creates a point at {@code vector} with zero min/max. */
    public SimpleVector( double vector )
    {
        this.vector = vector;
    }

    /** Returns the coordinate. */
    public double vector( )
    {
        return vector;
    }

    /**
     * Returns the Euclidean norm of this 1-D vector, i.e. the absolute value of its coordinate
     * (a norm is never negative).
     */
    public double norm( )
    {
        return Math.abs( vector );
    }

    /** Returns a new point with the same coordinate; min/max are not copied. */
    public SimpleVector toDense( )
    {
        return new SimpleVector( vector );
    }

    /** Returns the upper spread offset. */
    public double max( )
    {
        return max;
    }

    /** Sets the upper spread offset and returns this instance. */
    public SimpleVector setMax( double max )
    {
        this.max = max;
        return this;
    }

    /** Returns the lower spread offset. */
    public double min( )
    {
        return min;
    }

    /** Sets the lower spread offset and returns this instance. */
    public SimpleVector setMin( double min )
    {
        this.min = min;
        return this;
    }
}


/**
 * Returns the squared Euclidean distance between two 1-D vectors.
 * <p>
 * Callers compare the result with {@code epsilon * epsilon}, so it must be squared for that
 * test to mean "the center moved by more than epsilon".
 */
private static Double fastSquaredDistance( SimpleVector v1,
        SimpleVector v2 )
{
    double diff = v1.vector( ) - v2.vector( );
    return diff * diff;
}


/**
 * Serializable base for one-argument Scala functions, so anonymous subclasses can be passed
 * from Java to Scala APIs such as Array.tabulate (including inside Spark closures).
 */
private static abstract class MyFunction1<T1, R> extends AbstractFunction1<T1, R> implements Serializable
{


}


/**
 * Serializable base for two-argument Scala functions.
 * NOTE(review): not referenced in the visible code.
 */
private static abstract class MyFunction2<T1, T2, R> extends AbstractFunction2<T1, T2, R> implements Serializable
{


}


/**
 * Serializable base for zero-argument Scala functions, used as the element supplier for
 * Array.fill calls.
 */
private static abstract class MyFunction0<R> extends AbstractFunction0<R> implements Serializable
{


}


/**
 * Accumulator parameter for summing doubles (used for the per-run cost accumulators).
 * <p>
 * {@code zero} must return the additive identity, not the initial value. Per the
 * AccumulableParam contract, Spark uses zero(initialValue) to initialise task-side copies,
 * so echoing a non-zero initial value would count it once per task.
 */
private static final AccumulatorParam<Double> doubleAccumulatorParam = new AccumulatorParam<Double>( )
{

    @Override
    public Double addInPlace( Double arg0, Double arg1 )
    {
        return arg0 + arg1;
    }

    @Override
    public Double zero( Double initialValue )
    {
        return 0.0;
    }

    @Override
    public Double addAccumulator( Double arg0, Double arg1 )
    {
        return arg0 + arg1;
    }
};


/**
 * Trains a model with an explicit run count and initialization mode, mirroring Spark
 * MLlib's public {@code KMeans.train(data, k, maxIterations, runs, initializationMode)}.
 *
 * @param data               values to cluster
 * @param k                  number of clusters
 * @param maxIterations      maximum number of Lloyd's iterations
 * @param runs               number of independent runs; the cheapest one is returned
 * @param initializationMode {@link #RANDOM} or {@link #K_MEANS_PARALLEL}
 * @return the fitted model
 */
public static SimpleKMeansModel train( RDD<Double> data, int k,
        int maxIterations, int runs, String initializationMode )
{
    return new SimpleKMeans( )
            .setK( k )
            .setMaxIterations( maxIterations )
            .setRuns( runs )
            .setInitializationMode( initializationMode )
            .run( data );
}

/**
 * Trains a model with a single run and k-means|| initialization.
 *
 * @param data          values to cluster
 * @param k             number of clusters
 * @param maxIterations maximum number of Lloyd's iterations
 * @return the fitted model
 */
public static SimpleKMeansModel train( RDD<Double> data, int k,
        int maxIterations )
{
    return train( data, k, maxIterations, 1, K_MEANS_PARALLEL );
}
}



  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值