用kmeans split 连续值

Spark 决策树对连续值的处理使用 Sort 方法,MinMax 和 ApproxHist 方法都没有实现。其实可以使用 KMeans 做 split,这方面的论文也很多。

测试的代码如下,

findSplitsForContinuousFeature 的方法就是SPARK DecisionTree 的方法。

需要改进的

1:KMeans 计算后的 centers 可以直接用作 split,这里是为了保持一致才寻找最近的点。

2:有的论文中有一个minimum_records_per_cluster,如果遇到这种情况 split number需要减少,重新计算。

3:由于数据是一维的 Vector,再加上第 2 点的原因,其实可以不使用 Spark 的 KMeans,直接计算效率更高。

4:在实际中list会很大,应该用RDD处理。


import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;


import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vectors;


import breeze.linalg.min;
import scala.Tuple2;


/**
 * Demo that compares two ways of choosing split thresholds for a continuous
 * feature: the sort/stride based strategy (the approach the accompanying notes
 * attribute to Spark's DecisionTree) and a KMeans based strategy in which every
 * cluster center is snapped to the nearest observed sample value.
 */
public class SortSplitTest
{

    /**
     * Builds a small sample with repeated values, computes the splits with both
     * strategies and prints them to standard output.
     *
     * @param args unused
     */
    public static void main( String[] args )
    {
        List<Double> dList = new ArrayList<Double>( );

        generateTestData( dList, 1.0, 5 );
        generateTestData( dList, 2.0, 3 );
        generateTestData( dList, 3.0, 5 );
        generateTestData( dList, 4.0, 8 );

        Double[] sortSplits = findSplitsForContinuousFeature( dList, 2 );
        Double[] kMeansSplits = testKMeansSplit( dList, 2 );

        System.out.println( "sort splits:" );
        for ( double d : sortSplits )
        {
            System.out.println( d );
        }

        System.out.println( "kmeans splits:" );
        for ( double d : kMeansSplits )
        {
            System.out.println( d );
        }
    }

    /**
     * Computes split thresholds by clustering the samples with Spark MLlib
     * KMeans and mapping each cluster center to the closest sample value.
     *
     * @param list      the feature samples
     * @param numSplits the number of clusters requested from KMeans
     * @return the chosen sample values in ascending order (may contain
     *         duplicates when two centers share the same nearest sample);
     *         empty when {@code list} is null or empty
     */
    private static Double[] testKMeansSplit( List<Double> list, int numSplits )
    {
        // Nothing to cluster; avoid starting a Spark context for no reason.
        if ( list == null || list.isEmpty( ) )
        {
            return new Double[0];
        }

        SparkConf conf = new SparkConf( ).setAppName( "KMeans split" ).setMaster( "local" );
        JavaSparkContext ctx = new JavaSparkContext( conf );
        try
        {
            JavaRDD<Double> rdd = ctx.parallelize( list );

            // Each sample becomes a one-dimensional vector; 10 is the iteration cap.
            KMeansModel model = KMeans.train(
                    rdd.map( s -> Vectors.dense( new double[]{ s } ) ).rdd( ), numSplits, 10 );
            List<Double> centers = Arrays.stream( model.clusterCenters( ) )
                    .map( s -> s.toArray( )[0] )
                    .collect( Collectors.toList( ) );

            List<Double> retValue = new ArrayList<Double>( centers.size( ) );
            for ( double center : centers )
            {
                retValue.add( nearestSample( list, center ) );
            }

            // Natural ordering returns 0 for equal values, so duplicates cannot
            // trip the sort's comparator-contract check.
            retValue.sort( Comparator.naturalOrder( ) );
            return retValue.toArray( new Double[retValue.size( )] );
        }
        finally
        {
            // Always release the Spark context, even when training fails.
            ctx.stop( );
        }
    }

    /**
     * Returns the sample closest to {@code target}; the first such sample wins
     * on ties. Returns 0.0 when {@code samples} is empty.
     */
    private static double nearestSample( List<Double> samples, double target )
    {
        double bestDistance = Double.MAX_VALUE;
        double best = 0.0;
        for ( double sample : samples )
        {
            double distance = Math.abs( sample - target );
            if ( distance < bestDistance )
            {
                best = sample;
                bestDistance = distance;
            }
        }
        return best;
    }

    /**
     * Appends {@code value} to {@code dList} {@code times} times.
     */
    private static void generateTestData( List<Double> dList, double value, int times )
    {
        for ( int i = 0; i < times; i++ )
        {
            dList.add( value );
        }
    }

    /**
     * Computes split thresholds by walking the distinct sample values in
     * ascending order and emitting a split whenever the running count before
     * the current value is closer to the next target count than the running
     * count including it. Target counts advance in steps of
     * {@code featureSamples.size() / (numSplits + 1)}.
     *
     * @param featureSamples the feature samples
     * @param numSplits      the desired number of splits
     * @return the chosen thresholds in ascending order; empty when
     *         {@code featureSamples} is null or empty
     */
    private static Double[] findSplitsForContinuousFeature(
            List<Double> featureSamples, int numSplits )
    {
        // Without samples there is no first value to seed the running count.
        if ( featureSamples == null || featureSamples.isEmpty( ) )
        {
            return new Double[0];
        }

        // Count the occurrences of every distinct value.
        Map<Double, Integer> valueCountMap = new HashMap<Double, Integer>( );
        for ( Double sample : featureSamples )
        {
            valueCountMap.merge( sample, 1, Integer::sum );
        }

        // Distinct values with their counts, ordered by value.
        List<Entry<Double, Integer>> valueCounts = valueCountMap.entrySet( ).stream( )
                .sorted( Entry.comparingByKey( ) )
                .collect( Collectors.toList( ) );

        double stride = (double) featureSamples.size( ) / ( numSplits + 1 );
        List<Double> splitsBuilder = new ArrayList<Double>( );
        int currentCount = valueCounts.get( 0 ).getValue( );
        double targetCount = stride;

        for ( int index = 1; index < valueCounts.size( ); index++ )
        {
            int previousCount = currentCount;
            currentCount += valueCounts.get( index ).getValue( );
            double previousGap = Math.abs( previousCount - targetCount );
            double currentGap = Math.abs( currentCount - targetCount );

            // Split after the previous value if adding the current value would
            // move the running count further away from the target.
            if ( previousGap < currentGap )
            {
                splitsBuilder.add( valueCounts.get( index - 1 ).getKey( ) );
                targetCount += stride;
            }
        }

        return splitsBuilder.toArray( new Double[splitsBuilder.size( )] );
    }

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值