学习来源:日撸 Java 三百行(61-70天,决策树与集成学习)_闵帆的博客-CSDN博客
树桩分类器
树桩分类器每次只将数据分成两堆。
SimpleClassifier抽象类
这个类主要是作为树桩分类器的接口,其中的训练和分类函数都没有实现。但实现了计算正确数组、计算训练正确率和计算权重错误。
计算正确数组函数
将分类结果和实际结果做对比,如果正确为true,错误就为false。
返回结果数组,例如:[true,true,false]表示第一二个数据分类正确,第三个数据分类错误。
计算训练正确率
分类正确个数/总数据个数。
计算权重错误
将所有分类错误的数据所赋予的权重求和。如果求和结果小于10的负6次方,则返回10的负6次方(避免后续计算出现除零或NaN)。
package 日撸Java300行_61_70;
import java.util.Random;
import weka.core.Instance;
/**
 * Base class shared by all simple (stump-style) classifiers.
 * <p>
 * Subclasses must supply the training procedure and the per-instance
 * prediction; this class provides the bookkeeping that is common to every
 * stump: which instances are classified correctly, the training accuracy,
 * and the weighted training error used by boosting.
 *
 * @author Hui Xiao
 */
public abstract class SimpleClassifier {
	/**
	 * The index of the current attribute.
	 */
	int selectedAttribute;

	/**
	 * Weighted data.
	 */
	WeightedInstances weightedInstances;

	/**
	 * The accuracy on the training set.
	 */
	double trainingAccuracy;

	/**
	 * The number of classes. For binary classification it is 2.
	 */
	int numClasses;

	/**
	 * The number of instances.
	 */
	int numInstances;

	/**
	 * The number of conditional attributes.
	 */
	int numConditions;

	/**
	 * For random number generation.
	 */
	Random random = new Random();

	/**
	 ******************
	 * The first constructor. Caches the basic statistics of the data set:
	 * the class attribute is assumed to be the last one, hence the number of
	 * conditional attributes is numAttributes() - 1.
	 *
	 * @param paraWeightedInstances
	 *            The given instances.
	 ******************
	 */
	public SimpleClassifier(WeightedInstances paraWeightedInstances) {
		weightedInstances = paraWeightedInstances;

		numConditions = weightedInstances.numAttributes() - 1;
		numInstances = weightedInstances.numInstances();
		numClasses = weightedInstances.classAttribute().numValues();
	}// Of the first constructor

	/**
	 ******************
	 * Train the classifier.
	 ******************
	 */
	public abstract void train();

	/**
	 ******************
	 * Classify an instance.
	 *
	 * @param paraInstance
	 *            The given instance.
	 * @return Predicted label.
	 ******************
	 */
	public abstract int classify(Instance paraInstance);

	/**
	 ******************
	 * Which instances in the training set are correctly classified.
	 *
	 * @return The correctness array: element i is true iff instance i is
	 *         predicted correctly by classify().
	 ******************
	 */
	public boolean[] computeCorrectnessArray() {
		int tempLength = weightedInstances.numInstances();
		boolean[] tempCorrectness = new boolean[tempLength];
		for (int i = 0; i < tempLength; i++) {
			Instance tempCurrentInstance = weightedInstances.instance(i);
			// Correct iff the predicted label matches the actual class value.
			tempCorrectness[i] = ((int) tempCurrentInstance.classValue() == classify(tempCurrentInstance));
		} // Of for i

		return tempCorrectness;
	}// Of computeCorrectnessArray

	/**
	 ******************
	 * Compute the accuracy on the training set, i.e., the fraction of
	 * correctly classified instances.
	 *
	 * @return The training accuracy in [0, 1].
	 ******************
	 */
	public double computeTrainingAccuracy() {
		boolean[] tempCorrectnessArray = computeCorrectnessArray();

		int tempNumCorrect = 0;
		for (boolean tempIsCorrect : tempCorrectnessArray) {
			if (tempIsCorrect) {
				tempNumCorrect++;
			} // Of if
		} // Of for tempIsCorrect

		return (double) tempNumCorrect / tempCorrectnessArray.length;
	}// Of computeTrainingAccuracy

	/**
	 ******************
	 * Compute the weighted error on the training set: the sum of the weights
	 * of all misclassified instances. It is at least 1e-6 to avoid NaN in
	 * later boosting computations.
	 *
	 * @return The weighted error, no smaller than 1e-6.
	 ******************
	 */
	public double computeWeightedError() {
		boolean[] tempCorrectnessArray = computeCorrectnessArray();

		double tempError = 0;
		for (int i = 0; i < tempCorrectnessArray.length; i++) {
			if (!tempCorrectnessArray[i]) {
				tempError += weightedInstances.getWeight(i);
			} // Of if
		} // Of for i

		// Floor the error to keep subsequent logarithms/divisions finite.
		return Math.max(tempError, 1e-6);
	}// Of computeWeightedError
} // Of class SimpleClassifier