集成学习之Adaboost
时间:2022/6/13
文章目录
1.集成学习
集成学习是一种提升分类器的性能的方法。通过整合多个 基学习器(Base learner来完成学习任务。集成学习被认为是一种 元算法
1.1学习器
-
强学习器(Strong learner):相对于弱学习器而言,强学习器指的是可以预测相当精准的学习器。
-
弱学习器(Weak learner):相对于强学习器而言,这类学习器的效果通常只比随机结果要好一点。
-
基学习器(Base learner):是集成学习中,每个单独的学习器即为基学习器。通常采用弱学习器,但不一定必须是弱学习器。
-
基学习算法(Base Learning Algorithm):基学习器所基于的算法,通过该算法生成相应的基学习器。
-
同质基学习器(Homogeneous Base Learner):采用相同的基学习算法生成的基学习器。
-
异质基学习器(Heterogeneous Base Learner):采用不同基学习算法生成的基学习器。在异质方法中,通常称为**组件学习器(component learner)**或者是叫做个体学习器
1.2集成学习算法
通常来说,生成一个完整的集成学习算法的步骤可以大致分为两步:
-
构建基学习器:生成一系列基学习器,这个过程可以是并行的,也可以是串行的。在并行生成时,相互之间的基学习器是相互独立的,而在串行的生成过程中,前期生成的基学习器会对后续生成的学习器有影响。
-
组合基学习器:这些基学习器被组合起来使用,最常见的组合方法比如用于分类的多数投票(majority voting),以及用于回归的权重平均(weighted averaging)
集成学习的构建方法主要分为两类:
-
并行化方法。
-
构建多个独立的学习器,取预测结果平均值。
-
个体学习器之间不存在强依赖关系,一个系列个体学习器可以并行生成
-
通常是同质的弱学习器
-
代表算法是Bagging和随机森林(Random Forest系列算法
-
-
序列化方法
-
多个学习器是依次构建的
-
个体学习器之间存在强依赖关系,因为一系列个体学习器需要串行生成。
-
通常是异质的学习器。
-
代表算法是Boosting系列算法,比如AdaBoost,梯度提升树等
-
1.3学习器的选择
考虑二分类问题 y ∈ { − 1 , + 1 } y\in\{-1,+1\} y∈{−1,+1}和真实函数 f ( x ) f(x) f(x)假设分类器的错误率为 ϵ \epsilon ϵ,即对于每个基分类器 h i h_i hi有:
P ( h i ( x ) ≠ f ( x ) ) = ϵ (1) P(h_i(x)\ne f(x))=\epsilon\tag{1} P(hi(x)=f(x))=ϵ(1)
假设集成器通过简单投票法结合所有 T T T个基学习器,若超过半数的基分类器正确,则集成分类正确。表示为:
H ( x ) = s i g n ( ∑ i = 1 T h i ( x ) ) (2) H(x)=sign(\sum_{i=1}^Th_i(x))\tag{2} H(x)=sign(i=1∑Thi(x))(2)
假设基分类器的错误率相互独立,则由Hoeffding不等式可得,集成器的错误率为:
P ( H ( x ) ≠ f ( x ) ) = ∑ k = 0 ⌊ T / 2 ⌋ ( T k ) ( 1 − ϵ ) k ϵ T − k ≤ e x p ( − 1 2 T ( 1 − 2 ϵ ) 2 ) (3) P(H(x)\ne f(x))=\sum_{k=0}^{\lfloor T/2 \rfloor}{{T \choose k}(1-\epsilon)^k\epsilon^{T-k}\le exp(-{1\over 2}T(1-2\epsilon)^2})\tag{3} P(H(x)=f(x))=k=0∑⌊T/2⌋(kT)(1−ϵ)kϵT−k≤exp(−21T(1−2ϵ)2)(3)
上式表明,当基分类器的数量T增大时,集成的错误率将指数级下降,最终趋于0。
但实际上基分类器的错误率实际上是不可能相互独立的。因为个体学习器都是为了解决同一问题而提出的。故学习器的选择,要尽量满足“好而不同”的特点。如何选择“好而不同”的学习器,便是集成学习研究的重点之一。
2.AdaBoost
2.1Boosting
Boosting的主要机制:先从初始训练集中学习出一个基学习器,再根据基学习器的表现对训练样本分布进行调整,使得先前基学习器分类错误的训练样本在后续受到更多关注,然后基于调整之后的样本分布来训练下一个基学习器。如此重复进行,直至基学习器数目达到事先指定的值 T T T.最终将这T个基学习器进行加权结合。
2.2AdaBoosting
AdaBoosting算法思想如上图所示。数据集初始权值为1/N,N为数据集中实例个数。保证权值的归一化。通过初始权值和数据可以通过基学习器产生一次投票权值,通过前一次分类器分类的错误率可以推出下一次分类器的权值。同时分类器分类后更新数据集的权值,作为下一个分类器的输入数据集。不断迭代,直至最后通过投票的方式,将所有基学习器的权值进行整合。得到分类结果。
分类器权值更新:
a t = 1 2 l n ( 1 − ϵ t ϵ t ) (4) a_t={1\over 2}ln({1-\epsilon_t \over \epsilon_t})\tag{4} at=21ln(ϵt1−ϵt)(4)
上式即为第t个基分类器的权值,该式是基于最小化指数损失函数推导而来。具体推导可见西瓜书。由该式可知,分类器分类错误率越大,权值越小
则基分类器的线性组合为:
H ( x ) = ∑ t = 1 T a t h t ( x ) (5) H(x)=\sum_{t=1}^Ta_th_t(x)\tag{5} H(x)=t=1∑Tatht(x)(5)
则最终的分类器为:
G ( x ) = s i g n ( H ( x ) ) = s i g n ( ∑ t = 1 T a t h t ( x ) ) (6) G(x)=sign(H(x))=sign(\sum_{t=1}^Ta_th_t(x))\tag{6} G(x)=sign(H(x))=sign(t=1∑Tatht(x))(6)
3.算法实现
3.1带权数据集
weka包中没有我们需要的带权数据集,故重写一个带权数据集用以存储数据和相应的权值。以及每次分类过后调整数据集权值的方法。
/**
* WeightedInstances.java
*
* @author zjy
* @date 2022/6/12
* @Description: 带权数据集
* @version V1.0
*/
package swpu.zjy.ML.AdaBoosting.myboost;
import weka.core.Instances;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.Arrays;
public class WeightedInstances extends Instances {
//序列化
private static final long serialVersionUID = 11087456L;
//权值
public double[] weights;
/**
* 构造方法,构造一个带权值的数据集
*
* @param reader 输入流
* @throws IOException
*/
public WeightedInstances(Reader reader) throws IOException {
super(reader);
setClassIndex(numAttributes() - 1);
weights = new double[numInstances()];
//初始权值定义为1/numInstances()
double tempAverage = 1.0 / numInstances();
for (int i = 0; i < numInstances(); i++) {
weights[i] = tempAverage;
}
System.out.println("Instances weights are: " + Arrays.toString(weights));
}
/**
* 构造方法,使用示例构造
*
* @param paraInstance 示例
*/
public WeightedInstances(Instances paraInstance) {
super(paraInstance);
setClassIndex(numAttributes() - 1);
// Initialize weights
weights = new double[numInstances()];
double tempAverage = 1.0 / numInstances();
for (int i = 0; i < weights.length; i++) {
weights[i] = tempAverage;
} // Of for i
System.out.println("Instances weights are: " + Arrays.toString(weights));
}
/**
* 获取权值
*
* @return 权值
*/
public double getWeights(int i) {
return weights[i];
}
/**
* 数据集权值调整,根据基分类器分类结果对权值进行调整。
* 分类正确,减小权值,分类错误,增大权值
*
* @param paraCorrectArray 分类结果数组,正确:true;错误:false;
* @param paraAlpha 分类器权值
*/
public void adjustWeights(boolean[] paraCorrectArray, double paraAlpha) {
//获取Alpha
double tempIncrease = Math.exp(paraAlpha);
//根据分类结果调整权值
double tempWeightsSum = 0;
for (int i = 0; i < weights.length; i++) {
if (paraCorrectArray[i]) {
weights[i] /= tempIncrease;
} else {
weights[i] *= tempIncrease;
}
tempWeightsSum += weights[i];
}
//归一化
for (int i = 0; i < weights.length; i++) {
weights[i] /= tempWeightsSum;
}
System.out.println("After adjusting, instances weights are: " + Arrays.toString(weights));
}
/**
* 方便展示,重写toString()
*
* @return
*/
public String toString() {
String resultString = "I am a weighted Instances object.\r\n" + "I have " + numInstances() + " instances and "
+ (numAttributes() - 1) + " conditional attributes.\r\n" + "My weights are: " + Arrays.toString(weights)
+ "\r\n" + "My data are: \r\n" + super.toString();
return resultString;
}
public static void main(String args[]) {
WeightedInstances tempWeightedInstances = null;
String tempFilename = "E:\\DataSet\\iris.arff";
try {
FileReader tempFileReader = new FileReader(tempFilename);
tempWeightedInstances = new WeightedInstances(tempFileReader);
tempFileReader.close();
} catch (Exception exception1) {
System.out.println("Cannot read the file: " + tempFilename + "\r\n" + exception1);
System.exit(0);
}
System.out.println(tempWeightedInstances.toString());
boolean[] tempCorrectArray = new boolean[tempWeightedInstances.numInstances()];
for (int i = 0; i < tempCorrectArray.length / 2; i++) {
tempCorrectArray[i] = true;
}
double tempWeightedError = 0.3;
tempWeightedInstances.adjustWeights(tempCorrectArray, tempWeightedError);
System.out.println("After adjusting");
System.out.println(tempWeightedInstances.toString());
}
}
3.2SimpleClassifierAbstract
对基分类器的抽象,抽象训练方法,分类方法。分类结果计算,分类错误率与正确率计算等。
/**
* SimpleClassifierAbstract.java
*
* @author zjy
* @date 2022/6/12
* @Description: 基分类器抽象类
* @version V1.0
*/
package swpu.zjy.ML.AdaBoosting.myboost;
import weka.core.Instance;
import java.util.Random;
public abstract class SimpleClassifierAbstract {
//带权数据集存储
public WeightedInstances weightedInstances;
//分裂属性
int selectedAttribute;
//训练准确率
double trainingAccuracy;
//类取值个数
int numClasses;
//实例数
int numInstances;
//条件属性
int numConditions;
//随机数
Random random=new Random();
/**
* 构造方法,初始化分类器
* @param paraWeightedInstance 带权数据集
*/
public SimpleClassifierAbstract(WeightedInstances paraWeightedInstance){
weightedInstances=paraWeightedInstance;
numClasses=weightedInstances.classAttribute().numValues();
numInstances=weightedInstances.numInstances();
numConditions=weightedInstances.numAttributes() - 1;
}
/**
* 抽象方法训练
*/
public abstract void train();
/**
* 分类方法,对传入实例进行分类
* @param paraInstance 需要分类的实例
* @return 分类类标号
*/
public abstract int classify(Instance paraInstance);
/**
* 统计分类正确
* @return 分类正确结果
*/
public boolean[] computeCorrectnessArray(){
boolean[] resultCorrect=new boolean[numInstances];
for (int i = 0; i < numInstances; i++) {
Instance tempInstance=weightedInstances.instance(i);
if(tempInstance.classValue()==classify(tempInstance)){
resultCorrect[i]=true;
}
}
return resultCorrect;
}
/**
* 计算训练分类准确率
* @return 分类准确率
*/
public double computeTrainingAccuracy(){
int tempNumCorrect=0;
boolean[] tempCorrectArray=computeCorrectnessArray();
for (int i = 0; i < tempCorrectArray.length; i++) {
if(tempCorrectArray[i]){
tempNumCorrect++;
}
}
return 1.0*tempNumCorrect/tempCorrectArray.length;
}
/**
* 计算分类错误率
* @return 分类错误
*/
public double computeWeightedError() {
double resultError = 0;
boolean[] tempCorrectnessArray = computeCorrectnessArray();
for (int i = 0; i < tempCorrectnessArray.length; i++) {
if (!tempCorrectnessArray[i]) {
resultError += weightedInstances.getWeights(i);
}
}
if (resultError < 1e-6) {
resultError = 1e-6;
}
return resultError;
}
}
3.3树桩分类器
树桩分类器是高度为一的决策树。通过简单算法获取最佳分类点,处理连续性数据。
/**
* StumpClassifier.java
*
* @author zjy
* @date 2022/6/12
* @Description: 树桩分类器,高度为一的决策树
* @version V1.0
*/
package swpu.zjy.ML.AdaBoosting.myboost;
import weka.core.Instance;
import java.io.FileReader;
import java.util.Arrays;
public class StumpClassifier extends SimpleClassifierAbstract{
double bestCut;
int leftLeafLabel;
int rightLeafLabel;
/**
* 构造方法,初始化分类器
*
* @param paraWeightedInstance 带权数据集
*/
public StumpClassifier(WeightedInstances paraWeightedInstance) {
super(paraWeightedInstance);
}
@Override
public void train() {
//step1.选择分裂属性
selectedAttribute=random.nextInt(numConditions);
//step2. 读入分裂属性数据
double[] tempValues=new double[numInstances];
for (int i = 0; i < tempValues.length; i++) {
tempValues[i]=weightedInstances.instance(i).value(selectedAttribute);
}
Arrays.sort(tempValues);
//step3.统计当前分裂点分裂结果标签
int tempNumClasses=numClasses;
double[] tempClassCountArray=new double[tempNumClasses];
int tempCurrentClassValue;
for (int i = 0; i < numInstances; i++) {
tempCurrentClassValue=(int)weightedInstances.instance(i).classValue();
tempClassCountArray[tempCurrentClassValue]+=weightedInstances.getWeights(i);
}
//找寻最多的标签
double tempMaxCorrect = 0;
int tempBestClass=-1;
for (int i = 0; i < tempClassCountArray.length; i++) {
if(tempMaxCorrect<tempClassCountArray[i]){
tempMaxCorrect=tempClassCountArray[i];
tempBestClass=i;
}
}
bestCut=tempValues[0]-0.1;
leftLeafLabel=tempBestClass;
rightLeafLabel=tempBestClass;
double tempCut;
//用于统计类标号,第一维是树桩分支数量,第二维是类标号数量;
double[][] tempClassCountMatrix=new double[2][tempNumClasses];
//找寻最佳分裂点
for (int i = 0; i < tempValues.length-1; i++) {
//若相邻的数据一致则跳过本次
if(tempValues[i]==tempValues[i+1]){
continue;
}
//计算分裂点
tempCut=(tempValues[i]+tempValues[i+1])/2;
//初始化矩阵
for (int j = 0; j < 2; j++) {
for (int k = 0; k < numClasses; k++) {
tempClassCountMatrix[j][k]=0;
}
}
//统计当前分裂点的分类结果
for (int j = 0; j < numInstances; j++) {
tempCurrentClassValue=(int)weightedInstances.instance(j).classValue();
if(weightedInstances.instance(j).value(selectedAttribute)<tempCut){
tempClassCountMatrix[0][tempCurrentClassValue]+=weightedInstances.getWeights(j);
}else {
tempClassCountMatrix[1][tempCurrentClassValue]+=weightedInstances.getWeights(j);
}
}
//统计左子树分类结果
double tempLeftMaxCorrect = 0;
int tempLeftBestLabel = 0;
for (int j = 0; j < numClasses; j++) {
if(tempClassCountMatrix[0][j]>tempLeftMaxCorrect){
tempLeftMaxCorrect=tempClassCountMatrix[0][j];
tempLeftBestLabel=j;
}
}
//统计右子树分类结果
double tempRightMaxCorrect = 0;
int tempRightBestLabel = 0;
for (int j = 0; j < numClasses; j++) {
if(tempClassCountMatrix[1][j]>tempRightMaxCorrect){
tempRightMaxCorrect=tempClassCountMatrix[1][j];
tempRightBestLabel=j;
}
}
//更新分裂点,左右子树类标号
if(tempMaxCorrect<tempLeftMaxCorrect+tempRightMaxCorrect){
tempMaxCorrect=tempLeftMaxCorrect+tempRightMaxCorrect;
bestCut=tempCut;
leftLeafLabel=tempLeftBestLabel;
rightLeafLabel=tempRightBestLabel;
}
}
System.out.println("Attribute = " + selectedAttribute + ", cut = " + bestCut + ", leftLeafLabel = "
+ leftLeafLabel + ", rightLeafLabel = " + rightLeafLabel);
}
/**
* 使用树桩分类器进行分类
* @param paraInstance 需要分类的实例
* @return
*/
@Override
public int classify(Instance paraInstance) {
int resultLabel = -1;
if (paraInstance.value(selectedAttribute) < bestCut) {
resultLabel = leftLeafLabel;
} else {
resultLabel = rightLeafLabel;
}
return resultLabel;
}
public String toString() {
String resultString = "I am a stump classifier.\r\n" + "I choose attribute #" + selectedAttribute
+ " with cut value " + bestCut + ".\r\n" + "The left and right leaf labels are " + leftLeafLabel
+ " and " + rightLeafLabel + ", respectively.\r\n" + "My weighted error is: " + computeWeightedError()
+ ".\r\n" + "My weighted accuracy is : " + computeTrainingAccuracy() + ".";
return resultString;
}
public static void main(String args[]) {
WeightedInstances tempWeightedInstances = null;
String tempFilename = "E:\\DataSet\\iris.arff";
try {
FileReader tempFileReader = new FileReader(tempFilename);
tempWeightedInstances = new WeightedInstances(tempFileReader);
tempFileReader.close();
} catch (Exception ee) {
System.out.println("Cannot read the file: " + tempFilename + "\r\n" + ee);
System.exit(0);
} // Of try
StumpClassifier tempClassifier = new StumpClassifier(tempWeightedInstances);
tempClassifier.train();
System.out.println(tempClassifier);
System.out.println(Arrays.toString(tempClassifier.computeCorrectnessArray()));
}// Of main
}
3.4集成器
集成器,使用简单投票作为结合方法。
/**
* Booster.java
*
* @author zjy
* @date 2022/6/13
* @Description: 集成器
* @version V1.0
*/
package swpu.zjy.ML.AdaBoosting.myboost;
import weka.core.Instance;
import weka.core.Instances;
import java.io.FileReader;
public class Booster {
//存储基分类器
SimpleClassifierAbstract[] classifiers;
//统计训练使用的分类器数量
int numClassifier;
//
boolean stopAfterConverge=false;
//记录基分类器权值
double[] classifierWeights;
//训练数据
Instances trainingData;
//测试数据
Instances testingData;
/**
* 构造方法,初始化集成器
* @param datasetFileName 数据集地址
*/
public Booster(String datasetFileName){
try {
FileReader fileReader=new FileReader(datasetFileName);
trainingData=new Instances(fileReader);
fileReader.close();
} catch (Exception e) {
e.printStackTrace();
}
trainingData.setClassIndex(trainingData.numAttributes()-1);
testingData=trainingData;
stopAfterConverge=true;
}
/**
* 设置基分类器数量,初始化基分类器相关参数
* @param numClassifier 基分类器数量
*/
public void setNumClassifier(int numClassifier) {
this.numClassifier = numClassifier;
classifiers=new SimpleClassifierAbstract[numClassifier];
classifierWeights=new double[numClassifier];
}
/**
* 集成器训练过程
*/
public void train(){
WeightedInstances weightedInstances=null;
double tempError;
numClassifier=0;
//构建基分类器
for (int i = 0; i < classifiers.length; i++) {
//step1.添加数据集
if(i==0){
//第一个基分类器初始化带权数据集
weightedInstances=new WeightedInstances(trainingData);
}else {
//其他基分类器使用上一个基分类器调整权值之后的数据集
weightedInstances.adjustWeights(classifiers[i-1].computeCorrectnessArray(),classifierWeights[i-1]);
}
//step2.训练当前分类器
classifiers[i]=new StumpClassifier(weightedInstances);
classifiers[i].train();
tempError = classifiers[i].computeWeightedError();
/**
* 核心,设置基分类器的权值
* 权值的确定,推导复杂,我也没有搞懂
*/
classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
//太小的权值修正为0
if (classifierWeights[i] < 1e-6) {
classifierWeights[i] = 0;
}
numClassifier++;
//判断分类准确率是否达到预期
if(stopAfterConverge){
double tempTrainingAccuracy=computeTrainingAccuracy();
System.out.println("The accuracy of the "+i+"th Booster is: " + tempTrainingAccuracy + "\r\n");
if(tempTrainingAccuracy>0.999999){
System.out.println("Stop at the round: " + i + " due to converge.\r\n");
break;
}
}
}
}
/**
* 使用集成器进行分类
* @param paraInstance 测试实例
* @return 预测标签
*/
public int classify(Instance paraInstance){
double[] tempLabelCountArray=new double[trainingData.classAttribute().numValues()];
//计算权值
for (int i = 0; i < numClassifier; i++) {
int tempLabel=classifiers[i].classify(paraInstance);
tempLabelCountArray[tempLabel]+=classifierWeights[i];
}
//分类器投票
int resultLabel=-1;
double tempMax=-1;
for (int i = 0; i < tempLabelCountArray.length; i++) {
if(tempMax<tempLabelCountArray[i]){
tempMax=tempLabelCountArray[i];
resultLabel=i;
}
}
return resultLabel;
}
/**
* 计算分类器的训练准确率
* @return 当前分类器的分类准确率
*/
public double computeTrainingAccuracy(){
double tempCorrect=0;
for (int i = 0; i < trainingData.numInstances(); i++) {
if(classify(trainingData.instance(i))==trainingData.instance(i).classValue()){
tempCorrect++;
}
}
return tempCorrect/trainingData.numInstances();
}
/**
* 测试入口
* @return 测试准确率
*/
public double test(){
System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");
return test(testingData);
}
/**
* 对指定数据集进行分类预测
* @param paraInstances 测试集
* @return 分类准确率
*/
public double test(Instances paraInstances) {
double tempCorrect = 0;
paraInstances.setClassIndex(paraInstances.numAttributes() - 1);
for (int i = 0; i < paraInstances.numInstances(); i++) {
Instance tempInstance = paraInstances.instance(i);
if (classify(tempInstance) == (int) tempInstance.classValue()) {
tempCorrect++;
}
}
double resultAccuracy = tempCorrect / paraInstances.numInstances();
System.out.println("The accuracy is: " + resultAccuracy);
return resultAccuracy;
}
public static void main(String[] args) {
System.out.println("Starting AdaBoosting!");
Booster booster=new Booster("E:\\DataSet\\iris.arff");
booster.setNumClassifier(100);
booster.train();
System.out.println("The training accuracy is: " + booster.computeTrainingAccuracy());
booster.train();
}
}
pInstance = paraInstances.instance(i);
if (classify(tempInstance) == (int) tempInstance.classValue()) {
tempCorrect++;
}
}
double resultAccuracy = tempCorrect / paraInstances.numInstances();
System.out.println("The accuracy is: " + resultAccuracy);
return resultAccuracy;
}
public static void main(String[] args) {
System.out.println("Starting AdaBoosting!");
Booster booster=new Booster("E:\\DataSet\\iris.arff");
booster.setNumClassifier(100);
booster.train();
System.out.println("The training accuracy is: " + booster.computeTrainingAccuracy());
booster.train();
}
}