Spark 决策树对于连续值的处理使用sort方法,MinMax和ApproxHist方法都没有实现。其实可以使用KMeans 做split,这方面的论文也很多。
测试的代码如下,
findSplitsForContinuousFeature 的方法就是SPARK DecisionTree 的方法。
需要改进的
1:KMeans 计算后的centers 可以直接用作split;这里为了与sort方法保持一致(split取真实数据点),才去寻找与每个center最近的数据点。
2:有的论文中有一个minimum_records_per_cluster,如果遇到这种情况 split number需要减少,重新计算。
3:因为数据是一维的Vector,再加上原因2,其实可以不使用Spark 的KMeans,直接计算效率更高。
4:在实际中list会很大,应该用RDD处理。
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vectors;
import breeze.linalg.min;
import scala.Tuple2;
/**
*
*/
/**
 * Demonstrates an alternative way of choosing split points for continuous
 * features in a decision tree: instead of Spark DecisionTree's sort-based
 * {@code findSplitsForContinuousFeature}, cluster the one-dimensional values
 * with KMeans and use the sample values nearest to the cluster centers as
 * split candidates.
 */
public class SortSplitTest
{
    /**
     * Builds a small skewed sample, computes splits with both the sort-based
     * and the KMeans-based method, and prints both results for comparison.
     */
    public static void main( String[] args )
    {
        List<Double> dList = new ArrayList<Double>( );
        generateTestData( dList, 1.0, 5 );
        generateTestData( dList, 2.0, 3 );
        generateTestData( dList, 3.0, 5 );
        generateTestData( dList, 4.0, 8 );

        Double[] sortSplits = findSplitsForContinuousFeature( dList, 2 );
        Double[] kmeansSplits = testKMeansSplit( dList, 2 );

        // Print both so the two strategies can actually be compared
        // (the original computed the sort-based splits but never used them).
        System.out.println( "sort-based splits:   " + Arrays.toString( sortSplits ) );
        System.out.println( "kmeans-based splits: " + Arrays.toString( kmeansSplits ) );
    }

    /**
     * Computes {@code numSplits} split candidates by running KMeans on the
     * samples and snapping each cluster center to the nearest actual sample
     * value (so the splits are real data points, consistent with the
     * sort-based method).
     *
     * @param list      one-dimensional feature samples
     * @param numSplits number of clusters / split candidates to produce
     * @return the nearest sample value to each cluster center, ascending
     */
    private static Double[] testKMeansSplit( List<Double> list, int numSplits )
    {
        SparkConf conf = new SparkConf( ).setAppName( "KMeans split" ).setMaster( "local" );
        JavaSparkContext ctx = new JavaSparkContext( conf );
        try
        {
            JavaRDD<Double> rdd = ctx.parallelize( list );
            // KMeans operates on Vectors, so wrap each scalar in a 1-d dense vector.
            KMeansModel model = KMeans.train(
                    rdd.map( s -> Vectors.dense( new double[]{ s } ) ).rdd( ), numSplits, 10 );
            List<Double> centers = Arrays.stream( model.clusterCenters( ) )
                    .map( s -> s.toArray( )[0] )
                    .collect( Collectors.toList( ) );

            // Snap each center to the closest sample value.
            List<Double> retValue = new ArrayList<Double>( );
            for ( double center : centers )
            {
                double bestDistance = Double.MAX_VALUE;
                double nearest = 0.0;
                for ( double sample : list )
                {
                    double distance = Math.abs( sample - center );
                    if ( distance < bestDistance )
                    {
                        nearest = sample;
                        bestDistance = distance;
                    }
                }
                retValue.add( nearest );
            }
            // BUG FIX: the original lambda never returned 0 for equal elements,
            // violating the Comparator contract (undefined sort behavior).
            retValue.sort( Comparator.naturalOrder( ) );
            return retValue.toArray( new Double[retValue.size( )] );
        }
        finally
        {
            // Always release the Spark context, even if training throws.
            ctx.stop( );
        }
    }

    /**
     * Appends {@code value} to {@code dList} {@code times} times.
     *
     * @param dList the list to append to (mutated in place)
     * @param value the value to repeat
     * @param times how many copies to add
     */
    private static void generateTestData( List<Double> dList, double value, int times )
    {
        for ( int i = 0; i < times; i++ )
        {
            dList.add( value );
        }
    }

    /**
     * Spark DecisionTree's sort-based split finder, ported to Java: count the
     * occurrences of each distinct value, walk the distinct values in
     * ascending order, and emit a split whenever the cumulative count crosses
     * the ideal stride ({@code sampleCount / (numSplits + 1)}).
     *
     * @param featureSamples the continuous feature values
     * @param numSplits      the desired number of splits
     * @return the chosen split values in ascending order; may contain fewer
     *         than {@code numSplits} entries, and is empty for empty input
     */
    private static Double[] findSplitsForContinuousFeature(
            List<Double> featureSamples, int numSplits )
    {
        // Count occurrences of each distinct value.
        Map<Double, Integer> valueCountMap = new HashMap<Double, Integer>( );
        for ( Double sample : featureSamples )
        {
            valueCountMap.merge( sample, 1, Integer::sum );
        }

        // Distinct values with their counts, sorted ascending by value.
        // BUG FIX: the original comparator (key subtraction mapped to 1/-1)
        // never returned 0 for equal keys, violating the Comparator contract.
        List<Entry<Double, Integer>> valueCounts = valueCountMap.entrySet( ).stream( )
                .sorted( Map.Entry.comparingByKey( ) )
                .collect( Collectors.toList( ) );

        // Guard: the original threw IndexOutOfBoundsException on empty input.
        if ( valueCounts.isEmpty( ) )
        {
            return new Double[0];
        }

        // Ideal number of samples per bin.
        double stride = (double) featureSamples.size( ) / ( numSplits + 1 );
        List<Double> splitsBuilder = new ArrayList<Double>( );
        int index = 1;
        int currentCount = valueCounts.get( 0 ).getValue( );
        double targetCount = stride;
        while ( index < valueCounts.size( ) )
        {
            int previousCount = currentCount;
            currentCount += valueCounts.get( index ).getValue( );
            double previousGap = Math.abs( previousCount - targetCount );
            double currentGap = Math.abs( currentCount - targetCount );
            // Split at the previous value when its cumulative count is closer
            // to the running target than the current one.
            if ( previousGap < currentGap )
            {
                splitsBuilder.add( valueCounts.get( index - 1 ).getKey( ) );
                targetCount += stride;
            }
            index += 1;
        }
        return splitsBuilder.toArray( new Double[splitsBuilder.size( )] );
    }
}