李航统计分析AdaBoost 提升书例子java代码

该博客展示了如何用Java代码实现AdaBoost算法,通过一个具体的例子详细解释了算法的迭代过程,包括数据处理、基函数的选择和更新、误差计算等关键步骤。
摘要由CSDN通过智能技术生成

Java代码

 


package xigua;

import java.util.Arrays;

import scala.Tuple2;

/**
 * 
 */

public class AdaBoostSquaredExample
{

    private static int dataLength = 10;
    private static int numIterations = 6;

    private static double[] SAMPLE = new double[]{
            5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05
    };

    public static void main( String[] args )
    {
        Data[] datas = genData( );
        Arrays.sort( datas, ( s1, s2 ) -> s1.x - s2.x );

        for ( int i = 0; i < datas.length; i++ )
        {
            // System.out.println( data[i] );
        }

        double[] bins = genBins( datas );

        double[] ws = new double[datas.length];
        Arrays.fill( ws, 1.0 / datas.length );

        double[] alphas = new double[numIterations];
        Arrays.fill( alphas, 0.0 );

        BaseFunction[] fs = new BaseFunction[numIterations];

        for ( int i = 0; i < numIterations; i++ )
        {
            Tuple2<BaseFunction, Double> t = findBestFunction( datas, bins );
            fs[i] = t._1;

            updateData( datas, t._1( ) );

        }
        // Test
        double totalError = 0.0;
        for ( int i = 0; i < datas.length; i++ )
        {
            double value = testData( datas[i], fs );
            System.out.println( SAMPLE[i] + "===" + value );
            double temp = SAMPLE[i] - value;
            totalError = totalError + temp * temp;
        }
        System.out.println( "TotalError == " + totalError );
        System.out.println( "Done" );
    }

    private static double testData( Data data, BaseFunction[] fs )
    {
        double value = 0.0;
        for ( int i = 0; i < fs.length; i++ )
        {
            if ( data.x <= fs[i].bin )
            {
                value = value + fs[i].leftMean;
            }
            else
            {
                value = value + fs[i].rightMean;
            }
        }

        return value;
    }

    private static void updateData( Data[] data, BaseFunction f )
    {
        double total = 0.0;
        for ( int i = 0; i < data.length; i++ )
        {
            if ( data[i].x <= f.bin )
            {
                data[i].y = data[i].y - f.leftMean;
            }
            else
            {
                data[i].y = data[i].y - f.rightMean;
            }
        }

    }

    private static Tuple2<BaseFunction, Double> findBestFunction( Data[] data,
            double[] bins )
    {
        double minSq = Double.MAX_VALUE;
        BaseFunction best = null;
        for ( int i = 0; i < bins.length; i++ )
        {
            BaseFunction f = new BaseFunction( bins[i] );
            double leftTotal = 0.0;
            int leftCount = 0;
            double rightTotal = 0.0;
            int rightCount = 0;
            for ( int j = 0; j < data.length; j++ )
            {
                if ( data[j].x < bins[i] )
                {
                    leftTotal = leftTotal + data[j].y;
                    leftCount++;
                }
                else
                {
                    rightTotal = rightTotal + data[j].y;
                    rightCount++;
                }

            }
            double leftMean = leftTotal / leftCount;
            double rightMean = rightTotal / rightCount;

            double totalSq = 0.0;
            for ( int j = 0; j < data.length; j++ )
            {
                if ( data[j].x < bins[i] )
                {
                    double value = data[j].y - leftMean;
                    totalSq = totalSq + value * value;
                }
                else
                {
                    double value = data[j].y - rightMean;
                    totalSq = totalSq + value * value;
                }
            }

            if ( totalSq < minSq )
            {
                best = f;
                minSq = totalSq;
                best.leftMean = leftMean;
                best.rightMean = rightMean;
            }

        }
        return new Tuple2<BaseFunction, Double>( best, minSq );
    }

    private static double[] genBins( Data[] data )
    {
        if ( data.length == 0 )
        {
            return new double[0];
        }
        double[] bins = new double[data.length - 1];
        for ( int i = 1; i < data.length; i++ )
        {
            bins[i - 1] = ( data[i - 1].x + data[i].x ) / 2.0;
        }
        return bins;
    }

    private static Data[] genData( )
    {
        Data[] retValue = new Data[dataLength];
        for ( int i = 0; i < dataLength; i++ )
        {
            retValue[i] = new Data( );
            retValue[i].x = i + 1;
            retValue[i].y = SAMPLE[i];

        }
        return retValue;
    }

    private static class Data
    {

        private int x;
        private double y;

        @Override
        public String toString( )
        {
            return "X == " + x + "And Y == " + y;
        }

    }

    private static class BaseFunction
    {

        private double bin;
        private double leftMean, rightMean;

        BaseFunction( )
        {
            this( -1 );
        }

        BaseFunction( double b )
        {
            this.bin = b;
        }

    }
}
 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值