day65

package machinelearning.adaBoosting;

import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;

/**
 * ******************************************
 * The booster which ensembles base classifiers.
 *
 * @author Michelle Min MitchelleMin@163.com
 * @date 2021-07-28
 * ******************************************
 */
public class Booster {

    /**
     * Classifiers.
     */
    SimpleClassifier[] classifiers;

    /**
     * Number of classifiers.
     */
    int numClassifiers;

    /**
     * Whether to stop once the training error reaches 0.
     */
    boolean stopAfterConverge = false;

    /**
     * The weights of classifiers.
     */
    double[] classifierWeights;

    /**
     * The training data.
     */
    Instances trainingData;

    /**
     * The testing data.
     */
    Instances testingData;

    /**
     ********************
     * The first constructor. The testing set is the same as the training set.
     *
     * @param paraTrainingFilename
     *            The data filename.
     ********************
     */
    public Booster(String paraTrainingFilename){
        // Step 1. Read the training set.
        try{
            FileReader tempFileReader = new FileReader(paraTrainingFilename);
            trainingData = new Instances(tempFileReader);
            tempFileReader.close();
        } catch (Exception ee){
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            System.exit(0);
        }//of try

        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same as the training data.
        testingData = trainingData;

        stopAfterConverge = true;

        System.out.println("*********Data*********\r\n" + trainingData);
    }//of the first constructor

    /**
     ********************
     * Set the number of base classifiers, and allocate space for them.
     *
     * @param paraNumBaseClassifier
     *            The number of base classifiers.
     ********************
     */
    public void setNumClassifiers(int paraNumBaseClassifier){
        numClassifiers = paraNumBaseClassifier;

        // Step 1. Allocate space (references only) for the classifiers.
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Initialize classifier weights.
        classifierWeights = new double[numClassifiers];
    }//of setNumClassifiers

    /**
     ********************
     * Train the booster.
     ********************
     */
    public void train(){
        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build the classifiers one by one.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: Construct or adjust the weightedInstances
            if (i == 0) {
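                // First round: the WeightedInstances constructor presumably
                // assigns uniform weights 1/n to the n training instances
                // (an assumption about the companion class).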
                tempWeightedInstances = new WeightedInstances(trainingData);
            }else{
                // Adjust the weights of the data.
                tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            }//of if

            // Step 2.2 Train the next classifier.
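            // (A stump is a one-split weak learner; see StumpClassifier.)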
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

            tempError = classifiers[i].computeWeightedError();

            // Key code: set the classifier weight.
            classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            }//of if

            System.out.println("Classifier #" + i + ", weighted error = " + tempError + ", weight = "
                    + classifierWeights[i] +"\r\n");

            numClassifiers++;

            // Stop early once the training accuracy is high enough.
            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuracy();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at the round: " + i + " due to converge. \r\n");
                    break;
                }//of if
            }//of if
        }//of for i
    }//of train

    /**
     ********************
     * Classify an instance.
     *
     * @param paraInstance
     *            The given instance.
     * @return The predicted label.
     ********************
     */
    public int classify(Instance paraInstance){
        double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
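        // Weighted voting: each base classifier votes for its predicted label
        // with strength equal to its classifier weight.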
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelsCountArray[tempLabel] += classifierWeights[i];
        }//of for i

        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelsCountArray.length; i++) {
            if (tempMax < tempLabelsCountArray[i]) {
                tempMax = tempLabelsCountArray[i];
                resultLabel = i;
            }//of if
        }//of for i

        return resultLabel;
    }//of classify

    /**
     ********************
     * Test the booster on the training data.
     *
     * @return The classification accuracy.
     ********************
     */
    public double test(){
        System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

        return test(testingData);
    }//of test

    /**
     ********************
     * Test the booster.
     *
     * @param paraInstances
     *            The testing set.
     * @return The classification accuracy.
     ********************
     */
    public double test(Instances paraInstances){
        double tempCorrect = 0;
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int)tempInstance.classValue()) {
                tempCorrect++;
            }//of if
        }//of for i

        double resultAccuracy = tempCorrect / paraInstances.numInstances();
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    }//of test

    /**
     ********************
     * Compute the training accuracy of the booster. It is not weighted.
     *
     * @return The training accuracy.
     ********************
     */
    public double computeTrainingAccuracy(){
        double tempCorrect = 0;

        for (int i = 0; i < trainingData.numInstances(); i++) {
            if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
                tempCorrect++;
            }//of if
        }//of for i

        double tempAccuracy = tempCorrect / trainingData.numInstances();

        return tempAccuracy;
    }//of computeTrainingAccuracy

    /**
     ********************
     * For integration test.
     *
     * @param args
     *            Not provided.
     ********************
     */
    public static void main(String[] args){
        System.out.println("Starting AdaBoosting...");
        Booster tempBooster = new Booster("D:/mitchelles/data/iris.arff");

        tempBooster.setNumClassifiers(100);
        tempBooster.train();

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuracy());
        tempBooster.test();
    }//of main
}//of Booster
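
For reference, train() leans on a companion WeightedInstances class (built earlier in this series) whose adjustWeights method re-weights the training instances between rounds. The standard AdaBoost update multiplies the weight of each misclassified instance by e^alpha, divides the weight of each correctly classified one by the same factor, and then normalizes. A minimal sketch of such a method, assuming the class keeps a weights array aligned with the instances and that computeCorrectnessArray() returns a boolean[] (both are assumptions about the companion classes, not the code above), could look like this:

    // Hypothetical sketch only; the real method lives in the companion
    // WeightedInstances class from an earlier day of this series.
    public void adjustWeights(boolean[] paraCorrectArray, double paraAlpha) {
        // e^alpha: misclassified instances are boosted, correct ones damped.
        double tempIncrease = Math.exp(paraAlpha);
        double tempWeightsSum = 0;
        for (int i = 0; i < weights.length; i++) {
            if (paraCorrectArray[i]) {
                weights[i] /= tempIncrease;
            } else {
                weights[i] *= tempIncrease;
            }//of if
            tempWeightsSum += weights[i];
        }//of for i

        // Normalize so the weights again form a distribution.
        for (int i = 0; i < weights.length; i++) {
            weights[i] /= tempWeightsSum;
        }//of for i
    }//of adjustWeights

With this update, each new stump concentrates on the instances the previous ones handled badly, which is what lets a sequence of weak stumps reach high training accuracy.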
