1. Member Variables, Constructor, and Basic Initialization at a Glance
/**
* Classifiers.
*/
SimpleClassifier[] classifiers;
/**
* Number of classifiers.
*/
int numClassifiers;
/**
* Whether or not stop after the training error is 0.
*/
boolean stopAfterConverge = false;
/**
* The weights of classifiers.
*/
double[] classifierWeights;
/**
* The training data.
*/
Instances trainingData;
/**
* The testing data.
*/
Instances testingData;
- classifiers: an array of references holding all the base classifiers this ensemble needs
- numClassifiers: the number of base classifiers, equal to classifiers.length (an array field, not a method call)
- stopAfterConverge: whether the iteration may stop early once the accuracy is high enough (declared false here; the constructor sets it to true)
- classifierWeights: the array of weights, one per classifier
- trainingData / testingData: the training set and the testing set
2. Serial Design (Iterative Classifier Training): Idea and Code
The core idea of the ensemble is to use the misclassifications of a single stump classifier to correct the weight distribution, which in turn influences how the next stump classifies. Each round's distribution is therefore not a plain uniform split: the weights shift toward the labels and rows where more errors occur, so the ensemble keeps correcting its likely mistakes.
The ensemble process starts as follows. First we train a basic stump using a uniform weight array weightArray (every data row carries the same weight), which yields two pieces of information usable by the ensemble:
- the total weight of the rows the classifier misclassifies when tested on its own training data (the weighted error)
- a boolean array correctnessArray recording whether each row was classified correctly
Next, define the classifier-weight formula (Formula 1):

    alpha_m = (1/2) * ln((1 - epsilon_m) / epsilon_m) = (1/2) * ln(1/epsilon_m - 1)

where epsilon_m is the weighted error of classifier m. With this formula we can compute the current classifier's weight, and we can see that the larger a classifier's weighted error is, the smaller the classifier's own weight becomes.
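As a quick numeric check of this formula, here is a standalone sketch (class and method names are illustrative, not part of the Booster class); the expression is the same `0.5 * Math.log(1 / tempError - 1)` used in `train()`:

```java
// Standalone illustration of the classifier-weight formula
// alpha = 0.5 * ln((1 - epsilon) / epsilon) = 0.5 * ln(1 / epsilon - 1),
// matching the expression in train().
public class AlphaDemo {
    static double alpha(double weightedError) {
        return 0.5 * Math.log(1 / weightedError - 1);
    }

    public static void main(String[] args) {
        System.out.println(alpha(0.1)); // small error -> large classifier weight
        System.out.println(alpha(0.4)); // larger error -> smaller weight
        System.out.println(alpha(0.5)); // error 0.5 (no better than chance) -> weight 0
    }
}
```

Note that `train()` additionally clamps weights below 1e-6 to 0, which handles classifiers at or below chance level.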
Then we pin down a further definition: the influence factor built from the alpha_m of Formula 1. Each data row's weight is multiplied by e^(+alpha_m) if the row was misclassified and by e^(-alpha_m) if it was classified correctly, after which the weights are renormalized (Formula 2). This definition chains the result of one classifier to the next. Combining Formula 2 with the growth-rate asymmetry of the exponential function on its two half-axes in Formula 1, we can conclude: the more errors the previous classifier makes, the smaller its weight, and the smaller its boosting effect on the per-row weights seen by the following classifiers.
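The source of `WeightedInstances.adjustWeights` is not shown in this section, so the following is only a minimal sketch of its assumed behavior under Formula 2: scale each row's weight by e^(+alpha) or e^(-alpha) depending on correctness, then renormalize so the weights again sum to 1. All names here are illustrative.

```java
// Sketch (assumed behavior) of the per-row weight update:
// misclassified rows are boosted by e^{+alpha}, correct rows shrunk
// by e^{-alpha}, and the distribution is renormalized to sum to 1.
public class AdjustWeightsSketch {
    static double[] adjustWeights(double[] weights, boolean[] correctness, double alpha) {
        double[] result = new double[weights.length];
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            result[i] = weights[i] * Math.exp(correctness[i] ? -alpha : alpha);
            sum += result[i];
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= sum; // renormalize
        }
        return result;
    }

    public static void main(String[] args) {
        double[] w = { 0.25, 0.25, 0.25, 0.25 };
        boolean[] correct = { true, true, true, false };
        double[] w2 = adjustWeights(w, correct, 0.5);
        System.out.println(w2[3] > w2[0]); // the misclassified row gains weight
    }
}
```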
In short, the total weight computed from one classifier shapes the weight distribution handed to the next classifier, and thereby shapes the next model's learning process.
As this process iterates, the ensemble accumulates more and more classifiers, and according to AdaBoost theory the overall classification accuracy keeps improving. At this point we must consider a question: when do we stop? Several strategies follow:
Count-first: set an upper bound on the number of classifiers. Regardless of the accuracy the ensemble has reached, once the count hits the bound we stop, take the current ensemble as final, and train no further classifiers. (If the first few classifiers already suffice, the remaining computation is wasted.)
Accuracy-first: ignore how many classifiers are built and care only about whether the ensemble reaches a required accuracy threshold; stop as soon as it does. (If the threshold is set too high, training may run very long, or never terminate.)
Hybrid: as the name suggests, set both thresholds, one on the classifier count and one on accuracy, and stop building the ensemble as soon as either is met.
Our code uses the third strategy; it is the most reliable and avoids the drawbacks of the first two.
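The hybrid rule can be stated in one line. In `train()` the count cap is the loop bound and the accuracy threshold is the `0.999999` check; the following is a minimal standalone sketch (names are illustrative):

```java
// Sketch of the hybrid stopping rule: stop when either the number of
// trained classifiers reaches its cap or the training accuracy exceeds
// its threshold. Mirrors the loop bound and the 0.999999 test in train().
public class HybridStop {
    static boolean shouldStop(int trained, int maxClassifiers, double accuracy, double threshold) {
        return trained >= maxClassifiers || accuracy > threshold;
    }

    public static void main(String[] args) {
        System.out.println(shouldStop(100, 100, 0.9, 0.999999)); // true: count cap reached
        System.out.println(shouldStop(5, 100, 1.0, 0.999999));   // true: accuracy reached
        System.out.println(shouldStop(5, 100, 0.9, 0.999999));   // false: keep training
    }
}
```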
This raises an interesting discussion about ensembles: if the combination of classifiers has already reached essentially 100% accuracy, should we keep learning? In fact, continuing to learn is feasible: further rounds can improve the stability of the ensemble, helping it maintain good accuracy on similar future data. (We will back this claim up with data at the end!)
In summary, we can now give the implementation (the details, which mirror the process described above, are not repeated here):
/**
******************
* Train the booster.
*
* @see algorithm.StumpClassifier#train()
******************
*/
public void train() {
    // Step 1. Initialize.
    WeightedInstances tempWeightedInstances = null;
    double tempError;
    numClassifiers = 0;

    // Step 2. Build other classifiers.
    for (int i = 0; i < classifiers.length; i++) {
        // Step 2.1 Key code: Construct or adjust the weightedInstances
        if (i == 0) {
            tempWeightedInstances = new WeightedInstances(trainingData);
        } else {
            // Adjust the weights of the data.
            tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                    classifierWeights[i - 1]);
        } // Of if

        // Step 2.2 Train the next classifier.
        classifiers[i] = new StumpClassifier(tempWeightedInstances);
        classifiers[i].train();

        tempError = classifiers[i].computeWeightedError();
        // Key code: Set the classifier weight.
        classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
        if (classifierWeights[i] < 1e-6) {
            classifierWeights[i] = 0;
        } // Of if

        System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
                + classifierWeights[i] + "\r\n");

        numClassifiers++;

        // The accuracy is enough.
        if (stopAfterConverge) {
            double tempTrainingAccuracy = computeTrainingAccuray();
            System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
            if (tempTrainingAccuracy > 0.999999) {
                System.out.println("Stop at the round: " + i + " due to converge.\r\n");
                break;
            } // Of if
        } // Of if
    } // Of for i
}// Of train
3. Parallel Design (Classification and Accuracy): Idea and Code
To understand how the ensemble computes its accuracy, we first need to understand its strategy for classifying a single data row.
First, we have learned an ensemble containing n stump classifiers.
Given a test data row testingInstance, we run it through each stump classifier, and each one produces a predicted label.
We then use a one-dimensional array as a ballot box, with one bucket per class value (tempLabelsCountArray in the code below). Whenever a stump predicts a label with ID k, we add a weight to bucket k, and the weight added is that stump's own classifier weight taken from classifierWeights.
Finally, we find the maximum entry of the bucket array and return its index: the label with the largest accumulated weight wins the vote.
The code for this process is as follows:
/**
******************
* Classify an instance.
*
* @param paraInstance
* The given instance.
* @return The predicted label.
******************
*/
public int classify(Instance paraInstance) {
    double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
    for (int i = 0; i < numClassifiers; i++) {
        int tempLabel = classifiers[i].classify(paraInstance);
        tempLabelsCountArray[tempLabel] += classifierWeights[i];
    } // Of for i

    int resultLabel = -1;
    double tempMax = -1;
    for (int i = 0; i < tempLabelsCountArray.length; i++) {
        if (tempMax < tempLabelsCountArray[i]) {
            tempMax = tempLabelsCountArray[i];
            resultLabel = i;
        } // Of if
    } // Of for

    return resultLabel;
}// Of classify
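The weighted vote above can also be illustrated standalone, with hypothetical stump predictions and weights (this helper class is not part of Booster):

```java
// Standalone illustration of the weighted vote used in classify():
// each stump adds its own weight to the bucket of its predicted label,
// and the label with the largest total wins. Numbers are hypothetical.
public class VoteDemo {
    static int vote(int[] predictions, double[] weights, int numLabels) {
        double[] bucket = new double[numLabels];
        for (int i = 0; i < predictions.length; i++) {
            bucket[predictions[i]] += weights[i];
        }
        int best = 0;
        for (int k = 1; k < numLabels; k++) {
            if (bucket[k] > bucket[best]) {
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Stumps predict labels 0, 1, 0 with weights 0.8, 0.3, 0.5:
        // bucket = [1.3, 0.3], so label 0 wins.
        System.out.println(vote(new int[] { 0, 1, 0 }, new double[] { 0.8, 0.3, 0.5 }, 2));
    }
}
```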
Full code:
package machinelearning.adaboosting;
import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;
/**
* The booster which ensembles base classifiers.
*
* @author Rui Chen 1369097405@qq.com.
*/
public class Booster {

    /**
     * Classifiers.
     */
    SimpleClassifier[] classifiers;

    /**
     * Number of classifiers.
     */
    int numClassifiers;

    /**
     * Whether or not stop after the training error is 0.
     */
    boolean stopAfterConverge = false;

    /**
     * The weights of classifiers.
     */
    double[] classifierWeights;

    /**
     * The training data.
     */
    Instances trainingData;

    /**
     * The testing data.
     */
    Instances testingData;

    /**
     ******************
     * The first constructor. The testing set is the same as the training set.
     *
     * @param paraTrainingFilename
     *            The data filename.
     ******************
     */
    public Booster(String paraTrainingFilename) {
        // Step 1. Read training set.
        try {
            FileReader tempFileReader = new FileReader(paraTrainingFilename);
            trainingData = new Instances(tempFileReader);
            tempFileReader.close();
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            System.exit(0);
        } // Of try

        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same as the training data.
        testingData = trainingData;
        stopAfterConverge = true;

        System.out.println("****************Data**********\r\n" + trainingData);
    }// Of the first constructor

    /**
     ******************
     * Set the number of base classifiers, and allocate space for them.
     *
     * @param paraNumBaseClassifiers
     *            The number of base classifiers.
     ******************
     */
    public void setNumBaseClassifiers(int paraNumBaseClassifiers) {
        numClassifiers = paraNumBaseClassifiers;

        // Step 1. Allocate space (only references) for classifiers.
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Initialize classifier weights.
        classifierWeights = new double[numClassifiers];
    }// Of setNumBaseClassifiers

    /**
     ******************
     * Train the booster.
     *
     * @see algorithm.StumpClassifier#train()
     ******************
     */
    public void train() {
        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build other classifiers.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: Construct or adjust the weightedInstances
            if (i == 0) {
                tempWeightedInstances = new WeightedInstances(trainingData);
            } else {
                // Adjust the weights of the data.
                tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            } // Of if

            // Step 2.2 Train the next classifier.
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

            tempError = classifiers[i].computeWeightedError();
            // Key code: Set the classifier weight.
            classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            } // Of if

            System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
                    + classifierWeights[i] + "\r\n");

            numClassifiers++;

            // The accuracy is enough.
            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuray();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at the round: " + i + " due to converge.\r\n");
                    break;
                } // Of if
            } // Of if
        } // Of for i
    }// Of train

    /**
     ******************
     * Classify an instance.
     *
     * @param paraInstance
     *            The given instance.
     * @return The predicted label.
     ******************
     */
    public int classify(Instance paraInstance) {
        double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelsCountArray[tempLabel] += classifierWeights[i];
        } // Of for i

        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelsCountArray.length; i++) {
            if (tempMax < tempLabelsCountArray[i]) {
                tempMax = tempLabelsCountArray[i];
                resultLabel = i;
            } // Of if
        } // Of for

        return resultLabel;
    }// Of classify

    /**
     ******************
     * Test the booster on the training data.
     *
     * @return The classification accuracy.
     ******************
     */
    public double test() {
        System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");
        return test(testingData);
    }// Of test

    /**
     ******************
     * Test the booster.
     *
     * @param paraInstances
     *            The testing set.
     * @return The classification accuracy.
     ******************
     */
    public double test(Instances paraInstances) {
        double tempCorrect = 0;
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double resultAccuracy = tempCorrect / paraInstances.numInstances();
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    } // Of test

    /**
     ******************
     * Compute the training accuracy of the booster. It is not weighted.
     *
     * @return The training accuracy.
     ******************
     */
    public double computeTrainingAccuray() {
        double tempCorrect = 0;
        for (int i = 0; i < trainingData.numInstances(); i++) {
            if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double tempAccuracy = tempCorrect / trainingData.numInstances();
        return tempAccuracy;
    }// Of computeTrainingAccuray

    /**
     ******************
     * For integration test.
     *
     * @param args
     *            Not provided.
     ******************
     */
    public static void main(String args[]) {
        System.out.println("Starting AdaBoosting...");
        Booster tempBooster = new Booster("D:/data/iris.arff");
        // Booster tempBooster = new Booster("src/data/smalliris.arff");

        tempBooster.setNumBaseClassifiers(100);
        tempBooster.train();

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuray());
        tempBooster.test();
    }// Of main
}// Of class Booster