AdaBoosting 基本思想
首先了解一下AdaBoosting 。这是一个把多个弱分类器(分类精度略高于50%)叠加成为强分类器(分类精度高于90%)的算法。
其步骤如下:
- 初始化训练数据的权值。
- 训练弱分类器。如某个训练样本能被弱分类器准确分类,那么在构造下一个训练集时,其对应的权值减小。若被错误分类,那么权重增大。权值更新过的训练集会被用于训练下一个分类器。
- 组合所有的弱分类器为一个强分类器。加大分类误差率小的弱分类器的权重,使其在最终的分类函数中起着较大的决定作用,而降低分类误差率大的弱分类器的权重,使其在最终的分类函数中起着较小的决定作用。
关于权值调整
初始权值:若共有N个训练样本,则每个训练样本的权值为 $1/N$,权值和为1.0。
调整权值步骤:
- 若第i个样本分类正确,权值调整为:$w_i' = w_i e^{-\alpha}$。若分类错误则调整为:$w_i' = w_i e^{\alpha}$。其中 $\alpha$ 为参数,$w_i$ 为样本原本的权值。
- 为保证最终权值总和为1.0,第i个样本的权值应为:$w_i'' = w_i' / \sum_{j=1}^{N} w_j'$,其中 $w_j'$ 为第j个样本经上一步调整后(尚未归一化)的权值。
输入:arff数据集
输出:初始权值;调整后的权值;保证最终权值总和为1.0的权值。
优化目标:可能没有优化目标。
代码如下:
package knn5;
import java.io.FileReader;
import java.util.Arrays;
import weka.core.*;
public class WeightedInstances extends Instances{
private static final long serialVersionUID = 11087456L;
private double[] weights;
public WeightedInstances(FileReader paraFileReader) throws Exception{
super(paraFileReader);
setClassIndex(numAttributes() - 1);
weights = new double[numInstances()];
double tempAverage = 1.0/numInstances();
for(int i = 0; i < weights.length; i++) {
weights[i] = tempAverage;
} // Of for i
System.out.println("Instances weights are: " + Arrays.toString(weights));
} // Of the first constructor
public WeightedInstances(Instances paraInstances) {
super(paraInstances);
setClassIndex(numAttributes() - 1);
weights = new double[numInstances()];
double tempAverage = 1/numInstances();
for(int i = 0; i < weights.length; i++) {
weights[i] = tempAverage;
} // Of for i
System.out.println("Instances weights are: " + Arrays.toString(weights));
} // Of the second constructor
public double getWeight(int paraIndex) {
return weights[paraIndex];
}//Of getWeight
public void adjustWeights(boolean[] paraCorrectArray, double paraAlpha) {
double tempIncrease = Math.exp(paraAlpha);
double tempWeightsSum = 0;
for(int i =0; i < weights.length; i++) {
if (paraCorrectArray[i]) {
weights[i] /= tempIncrease;
} else {
weights[i] *= tempIncrease;
} // Of if
tempWeightsSum += weights[i];
} // Of for i
for (int i = 0; i < weights.length; i++) {
weights[i] /= tempWeightsSum;
} // Of for i
System.out.println("After adjusting, instances weights are: " + Arrays.toString(weights));
} // Of adjustWeights
public void adjustWeightsTest() {
boolean[] tempCorrectArray = new boolean[numInstances()];
for (int i = 0; i < tempCorrectArray.length / 2; i++) {
tempCorrectArray[i] = true;
} // Of for i
double tempWeightedError = 0.3;
adjustWeights(tempCorrectArray, tempWeightedError);
System.out.println("After adjusting");
System.out.println(toString());
} // Of adjustWeightsTest
public String toString() {
String resultString = "I am a weighted Instances object.\r\n" + "I have " + numInstances() + " instances and "
+ (numAttributes() - 1) + " conditional attributes.\r\n" + "My weights are: " + Arrays.toString(weights)
+ "\r\n";
return resultString;
} // Of toString
public static void main(String[] args) {
WeightedInstances tempWeightedInstances = null;
String tempFilename = "C:\\\\Users\\\\ASUS\\\\Desktop\\\\文件\\\\iris.arff";
try {
FileReader tempFileReader = new FileReader(tempFilename);
tempWeightedInstances = new WeightedInstances(tempFileReader);
tempFileReader.close();
} catch (Exception exception1) {
System.out.println("Cannot read the file: " + tempFilename + "\r\n" + exception1);
System.exit(0);
} // Of try
System.out.println(tempWeightedInstances.toString());
tempWeightedInstances.adjustWeightsTest();
} // Of main
} // Of class WeightedInstances
运行截图: