package machinelearning.adaBoosting;
import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;
/**
 * ******************************************
 * The booster which ensembles base classifiers (AdaBoost style): base
 * classifiers are trained one after another on re-weighted training data and
 * combined by weighted voting.
 *
 * @author Michelle Min MitchelleMin@163.com
 * @date 2021-07-28
 * ******************************************
 */
public class Booster {

    /**
     * The base classifiers. Only the first {@link #numClassifiers} entries are
     * trained and take part in voting.
     */
    SimpleClassifier[] classifiers;

    /**
     * Number of classifiers actually in use. It is reset by {@link #train()}
     * and counts the classifiers built so far.
     */
    int numClassifiers;

    /**
     * Whether or not to stop training as soon as the training accuracy of the
     * ensemble reaches 1.
     */
    boolean stopAfterConverge = false;

    /**
     * The voting weights of the classifiers, indexed like {@link #classifiers}.
     */
    double[] classifierWeights;

    /**
     * The training data.
     */
    Instances trainingData;

    /**
     * The testing data.
     */
    Instances testingData;

    /**
     ********************
     * The first constructor. The testing set is the same as the training set.
     * The last attribute is used as the class attribute, early stopping is
     * switched on, and the loaded data is printed.
     * <p>
     * If the file cannot be read, a message is printed and the JVM terminates
     * with a non-zero exit status.
     *
     * @param paraTrainingFilename
     *            The data filename.
     ********************
     */
    public Booster(String paraTrainingFilename) {
        // Step 1. Read the training set. try-with-resources closes the reader
        // even when parsing fails, so the file handle is never leaked.
        try (FileReader tempFileReader = new FileReader(paraTrainingFilename)) {
            trainingData = new Instances(tempFileReader);
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            // A failed load is an error, so report it through the exit status.
            System.exit(1);
        } // of try

        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same object as the training data.
        testingData = trainingData;

        stopAfterConverge = true;

        System.out.println("*********Data*********\r\n" + trainingData);
    }// of the first constructor

    /**
     ********************
     * Set the number of base classifiers, and allocate space for them and
     * for their weights. Must be called before {@link #train()}.
     *
     * @param paraNumBaseClassifier
     *            The number of base classifiers.
     ********************
     */
    public void setNumClassifiers(int paraNumBaseClassifier) {
        numClassifiers = paraNumBaseClassifier;

        // Step 1. Allocate space (only references) for the classifiers.
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Allocate the classifier weights (all zero initially).
        classifierWeights = new double[numClassifiers];
    }// of setNumClassifiers

    /**
     ********************
     * Train the booster. Builds up to {@code classifiers.length} stump
     * classifiers; each one after the first is trained on instance weights
     * adjusted according to the previous classifier's correctness and weight.
     *
     * @throws IllegalStateException
     *             If {@link #setNumClassifiers(int)} has not been called.
     ********************
     */
    public void train() {
        if (classifiers == null || classifierWeights == null) {
            throw new IllegalStateException(
                    "setNumClassifiers(int) must be called before train().");
        } // of if

        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build the classifiers one by one.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: construct or adjust the weighted instances.
            if (i == 0) {
                tempWeightedInstances = new WeightedInstances(trainingData);
            } else {
                // Adjust the weights of the data using the previous classifier.
                tempWeightedInstances.adjustWeights(
                        classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            } // of if

            // Step 2.2 Train the next classifier.
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

            tempError = classifiers[i].computeWeightedError();

            // Key code: set the classifier weight, 0.5 * ln((1 - e) / e).
            // NOTE(review): a weighted error of exactly 0 yields an infinite
            // weight here; confirm adjustWeights() tolerates that value.
            classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
            // A (near) non-positive weight means the classifier is no better
            // than chance, so it gets no vote.
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            } // of if

            System.out.println("Classifier #" + i + ", weighted error = " + tempError + ", weight = "
                    + classifierWeights[i] + "\r\n");

            // From now on this classifier takes part in classify().
            numClassifiers++;

            // The accuracy is enough.
            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuracy();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at the round: " + i + " due to converge. \r\n");
                    break;
                } // of if
            } // of if
        } // of for i
    }// of train

    /**
     ********************
     * Classify an instance by weighted voting of the trained classifiers.
     * Ties are resolved in favor of the smallest label index.
     *
     * @param paraInstance
     *            The given instance.
     * @return The predicted label.
     ********************
     */
    public int classify(Instance paraInstance) {
        // One accumulated vote weight per class label.
        double[] tempLabelVotes = new double[trainingData.classAttribute().numValues()];
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelVotes[tempLabel] += classifierWeights[i];
        } // of for i

        // Pick the label with the largest accumulated weight.
        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelVotes.length; i++) {
            if (tempMax < tempLabelVotes[i]) {
                tempMax = tempLabelVotes[i];
                resultLabel = i;
            } // of if
        } // of for i

        return resultLabel;
    }// of classify

    /**
     ********************
     * Test the booster on the testing data (which the constructor sets to the
     * training data).
     *
     * @return The classification accuracy.
     ********************
     */
    public double test() {
        System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

        return test(testingData);
    }// of test

    /**
     ********************
     * Test the booster on the given data and print the accuracy. As a side
     * effect, the class index of the given data is set to its last attribute.
     *
     * @param paraInstances
     *            The testing set.
     * @return The classification accuracy.
     ********************
     */
    public double test(Instances paraInstances) {
        double tempCorrect = 0;
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // of if
        } // of for i

        double resultAccuracy = tempCorrect / paraInstances.numInstances();
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    }// of test

    /**
     ********************
     * Compute the training accuracy of the booster. It is not weighted.
     *
     * @return The training accuracy.
     ********************
     */
    public double computeTrainingAccuracy() {
        double tempCorrect = 0;

        for (int i = 0; i < trainingData.numInstances(); i++) {
            Instance tempInstance = trainingData.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // of if
        } // of for i

        return tempCorrect / trainingData.numInstances();
    }// of computeTrainingAccuracy

    /**
     ********************
     * For integration test.
     *
     * @param args
     *            Optional. If present, the first argument is the ARFF file to
     *            load; otherwise the built-in default path is used.
     ********************
     */
    public static void main(String[] args) {
        System.out.println("Starting AdaBoosting...");

        String tempFilename = "D:/mitchelles/data/iris.arff";
        if (args != null && args.length > 0) {
            tempFilename = args[0];
        } // of if

        Booster tempBooster = new Booster(tempFilename);

        tempBooster.setNumClassifiers(100);
        tempBooster.train();

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuracy());
        tempBooster.test();
    }// of main
}// of Booster
08-30
08-30