李航统计分析AdaBoost例子java代码

8.1.3例子的java代码


package xigua;

import java.util.Arrays;

import scala.Tuple2;

/**
 * 
 */

public class AdaBoostExample
{

    private static int dataLength = 10;
    private static int numIterations = 3;

    public static void main( String[] args )
    {
        Data[] datas = genData( );
        Arrays.sort( datas, ( s1, s2 ) -> s1.x - s2.x );
        for ( int i = 0; i < datas.length; i++ )
        {
            // System.out.println( data[i] );
        }

        double[] bins = genBins( datas );

        double[] ws = new double[datas.length];
        Arrays.fill( ws, 1.0 / datas.length );

        double[] alphas = new double[numIterations];
        Arrays.fill( alphas, 0.0 );

        BaseFunction[] fs = new BaseFunction[numIterations];

        for ( int i = 0; i < numIterations; i++ )
        {
            Tuple2<BaseFunction, Double> t = findBestFunction( datas, bins, ws, alphas, fs, i );
            fs[i] = t._1;
            alphas[i] = 0.5 * Math.log( ( 1 - t._2 ) / t._2 );

            updateW( datas, ws, alphas[i], t._1( ) );
            System.out.println( t._1 );
            System.out.println( t._2 );
        }
        // fs and alphas are the final function;
        // Test
        for ( int i = 0; i < datas.length; i++ )
        {
            System.out.println( datas[i].y
                    + "==="
                    + testData( datas[i], fs, alphas ) );
        }
        System.out.println( "Done" );
    }

    private static int testData( Data data, BaseFunction[] fs, double[] alphas )
    {
        double value = 0.0;
        for ( int i = 0; i < fs.length; i++ )
        {
            value = value + alphas[i] * fs[i].predData( data );
        }

        return value >= 0 ? 1 : -1;
    }

    private static void updateW( Data[] data, double[] ws, double alpha,
            BaseFunction f )
    {
        double total = 0.0;
        for ( int i = 0; i < data.length; i++ )
        {
            ws[i] = ws[i]
                    * Math.exp( -alpha * data[i].y * f.predData( data[i] ) );
            total = total + ws[i];
        }
        for ( int i = 0; i < ws.length; i++ )
        {
            ws[i] = ws[i] / total;
        }
    }

    private static Tuple2<BaseFunction, Double> findBestFunction( Data[] data,
            double[] bins, double[] ws, double[] alphas, BaseFunction[] fs,
            int numIterations )
    {
        double minError = Double.MAX_VALUE;
        BaseFunction best = null;
        for ( int i = 0; i < bins.length; i++ )
        {
            BaseFunction f = new BaseFunction( bins[i] );
            double totalError0 = 0.0;
            double totalError1 = 0.0;
            for ( int j = 0; j < data.length; j++ )
            {
                f.setDirection( BaseFunction.FORWORAD );
                if ( f.predData( data[j] ) != data[j].y )
                {
                    totalError0 = totalError0 + ws[j] * 1;
                }
                
                f.setDirection( BaseFunction.BACKWARD );
                if ( f.predData( data[j] ) != data[j].y )
                {
                    totalError1 = totalError1 + ws[j] * 1;
                }

                //

            }
            if ( totalError0 < minError )
            {
                minError = totalError0;
                f.setDirection( BaseFunction.FORWORAD );
                best = f;
            }
            if ( totalError1 < minError )
            {
                minError = totalError1;
                f.setDirection( BaseFunction.BACKWARD );
                best = f;
            }
        }
        return new Tuple2<BaseFunction, Double>( best, minError );
    }

    private static double[] genBins( Data[] data )
    {
        if ( data.length == 0 )
        {
            return new double[0];
        }
        double[] bins = new double[data.length - 1];
        for ( int i = 1; i < data.length; i++ )
        {
            bins[i - 1] = ( data[i - 1].x + data[i].x ) / 2.0;
        }
        return bins;
    }

    private static Data[] genData( )
    {
        Data[] retValue = new Data[dataLength];
        for ( int i = 0; i < dataLength; i++ )
        {
            retValue[i] = new Data( );
            retValue[i].x = i;
            if ( i <= 2 )
            {
                retValue[i].y = 1;
            }
            else if ( i <= 5 )
            {
                retValue[i].y = -1;
            }
            else if ( i <= 8 )
            {
                retValue[i].y = 1;
            }
            else
            {
                retValue[i].y = -1;
            }

        }
        return retValue;
    }

    private static class Data
    {

        private int x;
        private int y;

        @Override
        public String toString( )
        {
            return "X == " + x + "And Y == " + y;
        }

    }

    private static class BaseFunction
    {
        private static int FORWORAD = 1, BACKWARD = -1;
        private int direction = FORWORAD;
        private double bin;

        BaseFunction( )
        {
            this( -1 );
        }

        BaseFunction( double b )
        {
            this.bin = b;
        }
        
        int predData( Data data )
        {
            if ( data.x < bin )
            {
                return 1*direction;
            }
            else
            {
                return -1*direction;
            }
        }

        
        public int getDirection( )
        {
            return direction;
        }

        
        public void setDirection( int direction )
        {
            this.direction = direction;
        }
    }
}
 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值