// Java code for the example in Section 8.1.3
package xigua;
import java.util.Arrays;
import scala.Tuple2;
/**
 * Small self-contained AdaBoost demonstration on a fixed one-dimensional data set.
 *
 * <p>The program generates ten labelled samples ({@code x = 0..9}, labels in
 * {@code {+1, -1}}), uses the midpoints between neighbouring samples as candidate
 * thresholds, and runs a fixed number of boosting rounds. Each round picks the
 * threshold classifier ("stump") with the smallest weighted error, derives its
 * coefficient {@code alpha = 0.5 * ln((1 - e) / e)}, and re-weights the samples.
 * Finally the weighted vote of all stumps is evaluated on the training samples and
 * printed as {@code label===prediction}.
 */
public class AdaBoostExample
{
    /** Number of training samples produced by {@link #genData()}. */
    private static final int DATA_LENGTH = 10;

    /** Number of boosting rounds, i.e. the number of stumps in the final ensemble. */
    private static final int NUM_ITERATIONS = 3;

    /**
     * Bound used to keep a weighted error strictly inside (0, 1) before computing
     * alpha; without it an error of exactly 0 or 1 yields an infinite alpha and the
     * following weight update produces NaN weights.
     */
    private static final double ERROR_EPSILON = 1e-10;

    /**
     * Trains the ensemble on the generated data and prints, per round, the chosen
     * stump and its weighted error, followed by one {@code label===prediction} line
     * per sample and a final {@code Done}.
     *
     * @param args ignored
     */
    public static void main( String[] args )
    {
        Data[] datas = genData( );
        // Integer.compare cannot overflow, unlike the subtraction idiom.
        Arrays.sort( datas, ( s1, s2 ) -> Integer.compare( s1.x, s2.x ) );

        double[] bins = genBins( datas );

        // Start from the uniform distribution over the samples.
        double[] ws = new double[datas.length];
        Arrays.fill( ws, 1.0 / datas.length );

        double[] alphas = new double[NUM_ITERATIONS];
        BaseFunction[] fs = new BaseFunction[NUM_ITERATIONS];
        for ( int i = 0; i < NUM_ITERATIONS; i++ )
        {
            FitResult fit = findBestFunction( datas, bins, ws );
            fs[i] = fit.function;
            alphas[i] = computeAlpha( fit.error );
            updateW( datas, ws, alphas[i], fit.function );
            System.out.println( fit.function );
            System.out.println( fit.error );
        }

        // fs and alphas together form the final classifier; evaluate it on the
        // training samples.
        for ( int i = 0; i < datas.length; i++ )
        {
            System.out.println( datas[i].y
                    + "==="
                    + testData( datas[i], fs, alphas ) );
        }
        System.out.println( "Done" );
    }

    /**
     * Computes the ensemble coefficient for a stump with the given weighted error.
     *
     * @param error weighted misclassification error of the stump
     * @return {@code 0.5 * ln((1 - e) / e)} where {@code e} is {@code error} clamped
     *         to {@code [ERROR_EPSILON, 1 - ERROR_EPSILON]}
     */
    private static double computeAlpha( double error )
    {
        double e = Math.min( Math.max( error, ERROR_EPSILON ), 1.0 - ERROR_EPSILON );
        return 0.5 * Math.log( ( 1 - e ) / e );
    }

    /**
     * Classifies one sample with the weighted vote of all stumps.
     *
     * @param data   sample to classify
     * @param fs     stumps of the ensemble
     * @param alphas coefficient of each stump, parallel to {@code fs}
     * @return {@code 1} if the weighted sum is non-negative, otherwise {@code -1}
     */
    private static int testData( Data data, BaseFunction[] fs, double[] alphas )
    {
        double value = 0.0;
        for ( int i = 0; i < fs.length; i++ )
        {
            value += alphas[i] * fs[i].predData( data );
        }
        return value >= 0 ? 1 : -1;
    }

    /**
     * Re-weights the samples in place after a boosting round: every weight is
     * multiplied by {@code exp(-alpha * y * f(x))} and the weights are then divided
     * by their sum so that they add up to one again.
     *
     * @param data  training samples
     * @param ws    sample weights, parallel to {@code data}; modified in place
     * @param alpha coefficient of the stump chosen in this round
     * @param f     stump chosen in this round
     */
    private static void updateW( Data[] data, double[] ws, double alpha,
            BaseFunction f )
    {
        double total = 0.0;
        for ( int i = 0; i < data.length; i++ )
        {
            ws[i] = ws[i]
                    * Math.exp( -alpha * data[i].y * f.predData( data[i] ) );
            total += ws[i];
        }
        for ( int i = 0; i < ws.length; i++ )
        {
            ws[i] = ws[i] / total;
        }
    }

    /**
     * Searches all thresholds in both directions for the stump with the smallest
     * weighted error. Comparisons are strict, so on ties the earliest threshold
     * wins and the forward direction is preferred over the backward one.
     *
     * @param data training samples
     * @param bins candidate thresholds
     * @param ws   current sample weights, parallel to {@code data}
     * @return the best stump together with its weighted error
     * @throws IllegalStateException if no candidate could be selected (no
     *                               thresholds, or no error compared below the
     *                               initial bound, e.g. because weights are NaN)
     */
    private static FitResult findBestFunction( Data[] data, double[] bins,
            double[] ws )
    {
        double minError = Double.MAX_VALUE;
        BaseFunction best = null;
        for ( int i = 0; i < bins.length; i++ )
        {
            BaseFunction forward = new BaseFunction( bins[i], BaseFunction.FORWARD );
            BaseFunction backward = new BaseFunction( bins[i], BaseFunction.BACKWARD );
            double forwardError = weightedError( forward, data, ws );
            double backwardError = weightedError( backward, data, ws );
            if ( forwardError < minError )
            {
                minError = forwardError;
                best = forward;
            }
            if ( backwardError < minError )
            {
                minError = backwardError;
                best = backward;
            }
        }
        if ( best == null )
        {
            throw new IllegalStateException(
                    "No usable threshold among " + bins.length + " candidate(s)" );
        }
        return new FitResult( best, minError );
    }

    /**
     * Sums the weights of all samples that the given stump misclassifies.
     *
     * @param f    stump to evaluate
     * @param data training samples
     * @param ws   sample weights, parallel to {@code data}
     * @return weighted misclassification error of {@code f}
     */
    private static double weightedError( BaseFunction f, Data[] data, double[] ws )
    {
        double error = 0.0;
        for ( int j = 0; j < data.length; j++ )
        {
            if ( f.predData( data[j] ) != data[j].y )
            {
                error += ws[j];
            }
        }
        return error;
    }

    /**
     * Builds the candidate thresholds as the midpoints between neighbouring samples.
     *
     * @param data samples, expected to be ordered by {@code x}
     * @return {@code data.length - 1} midpoints, or an empty array for empty input
     */
    private static double[] genBins( Data[] data )
    {
        if ( data.length == 0 )
        {
            return new double[0];
        }
        double[] bins = new double[data.length - 1];
        for ( int i = 1; i < data.length; i++ )
        {
            // Add as double so that two large int coordinates cannot overflow.
            bins[i - 1] = ( (double) data[i - 1].x + data[i].x ) / 2.0;
        }
        return bins;
    }

    /**
     * Generates the fixed training set: {@code x = 0 .. DATA_LENGTH - 1} with label
     * {@code +1} for {@code x <= 2}, {@code -1} for {@code 3..5}, {@code +1} for
     * {@code 6..8} and {@code -1} for everything above.
     *
     * @return the generated samples in ascending order of {@code x}
     */
    private static Data[] genData( )
    {
        Data[] retValue = new Data[DATA_LENGTH];
        for ( int i = 0; i < DATA_LENGTH; i++ )
        {
            retValue[i] = new Data( );
            retValue[i].x = i;
            if ( i <= 2 )
            {
                retValue[i].y = 1;
            }
            else if ( i <= 5 )
            {
                retValue[i].y = -1;
            }
            else if ( i <= 8 )
            {
                retValue[i].y = 1;
            }
            else
            {
                retValue[i].y = -1;
            }
        }
        return retValue;
    }

    /** One training sample: an integer coordinate {@code x} and its label {@code y}. */
    private static class Data
    {
        private int x;
        private int y;

        @Override
        public String toString( )
        {
            return "X == " + x + " And Y == " + y;
        }
    }

    /** Outcome of one stump search: the selected stump and its weighted error. */
    private static final class FitResult
    {
        private final BaseFunction function;
        private final double error;

        FitResult( BaseFunction function, double error )
        {
            this.function = function;
            this.error = error;
        }
    }

    /**
     * Threshold classifier. With direction {@link #FORWARD} it predicts {@code +1}
     * for {@code x < bin} and {@code -1} otherwise; {@link #BACKWARD} flips the sign.
     */
    private static final class BaseFunction
    {
        private static final int FORWARD = 1;
        private static final int BACKWARD = -1;

        private final int direction;
        private final double bin;

        BaseFunction( double bin, int direction )
        {
            this.bin = bin;
            this.direction = direction;
        }

        /**
         * Predicts the label of a sample.
         *
         * @param data sample to classify
         * @return {@code direction} if {@code data.x < bin}, otherwise {@code -direction}
         */
        int predData( Data data )
        {
            return data.x < bin ? direction : -direction;
        }

        @Override
        public String toString( )
        {
            return "BaseFunction[bin=" + bin + ", direction=" + direction + "]";
        }
    }
}