AdaBoost算法和java实现

AdaBoost算法和java实现


算法描述

输入:训练数据集这里写图片描述,其中 xi χ Rn , yi {-1,+1};弱学习算法; 
输出:最终分类器G(x)。

  1. 初始化训练集数据的权值分布 
    D1 =( w11 ,…, wiN ),  w1i =1/N, i=1,2…,N

  2. 对m=1,2,…,M

    • (a)使用具有权值分布 Dm 的训练数据集学习,得到基本分类器 
      Gm(x):χ> {-1,+1}

    • (b) 计算 Gm(x) 在训练数据集上的分类误差率 
      em= P(Gm(xi)yi)=Ni=1wmiI(Gm(xi)yi)


    • (c) 计算 Gx 的系数 
      αm=12log1emem 这里的对数是自然对数。

  • (d)更新训练数据集的权值分布 
    Dm+1=(wm+1,1,...wm+1,N)

    wm+1,i=wmiZmexp(αmyiGm(xi)),i=1,2,...,N  
    , Zm 是规范化因子

    Zm=Ni=1wmiexp(αmyiGm(xi))  
    它是 Dm+1 成为一个概率分布。


3. 构建基本分类器的线性组合

f(x)= Mm+1αmGm(x)

得到最终分类器

G(x)=sign(f(x))=sign(Mm=1αmGm(x))


举例说明

数据如下 
这里写图片描述 
当m=1时, 
根据以上的公式有 D1 =( w1i,w2i,...,w2i ), w1i=0.1,i=1,2,...,10 然后在权值分布为 D1 的训练数据集上,阈值v取2.5时分类的误差率最低,故分类器为 
注意 
在训练集上的误差率 e1 =3*0.1(3表示有三个分类错误的数据,0.1对应权值数组 D1 上的值)

按照(c)中的公式据算  α1=12log1e1e1 =0.4236

更新数据的权值分布: 
D2 =(0.07143,0.07143,0.07143,0.07143,0.07143,0.07143,0.16667,0.16667,0.16667,0.07143)()大家可以发现被错误分类的点的权值被加大了 
f1(x)=α1G1(x) =0.4236 G1(x)  
分类器sign[ f1(x) ]在训练数据集上有三个错误分类点。


当m=2时, 
-在权值分布为 D2 的训练数据集上 ,阈值v是8.5时分类误差率最低,基本分类器为

这里写图片描述 
G2(x) 在训练数据集上的误差率 e2 =0.2143 
- 计算 α2 =0.6496 
-更新训练数据集权值分布: 
D3 =(0.455,0.455,0.455,0.1667,0.1667,0.1667,0.1060,0.1060,0.1060,0.0455) 
f2(x) =0.4236 G1(x)+0.6496G2(x)  
分类器sign[ f2(x) ]在训练数据集上有三个错误分类点。


当m=3时, 
-在权值分布为 D2 的训练数据集上 ,阈值v是8.5时分类误差率最低,基本分类器为

这里写图片描述 
G2(x) 在训练数据集上的误差率 e3 =0.1820 
- 计算 α3 =0.7514 
-更新训练数据集权值分布: 
D3 =(0.125,0.125,0.125,0.102,0.102,0.102,0.065,0.065,0.065,0.125) 
f2(x) =0.4236 G1(x)+0.6496G2(x)+0.7514G3(x)  
分类器sign[ f2(x) ]在训练数据集上有0个错误分类点。 
故: G(x)=sign[f3(x)]=sign[0.4236G1(x)+0.6496G2(x)+0.7514G3(x)]

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

public class Test08 {
    public  ArrayList<String> list=new ArrayList<String>();
    public static final double k = 0.5;
    public static void main(String[] args) {
        // TODO Auto-generated method stub
        Test08 test=new Test08();
        Map<Integer, Integer> map = new HashMap<Integer, Integer>();
        map.put(0, 1);
        map.put(1, 1);
        map.put(2, 1);
        map.put(7, 1);
        map.put(5, -1);
        map.put(6, 1);
        map.put(8, 1);
        map.put(9, -1);
        map.put(3, -1);
        map.put(4, -1);
        System.out.println(test.adaBoost(test.sortMapByKey(map)));
    }

    public TreeMap<Integer, Integer>  sortMapByKey(Map<Integer, Integer> oriMap) {
        if (oriMap == null || oriMap.isEmpty()) {
            return null;}
        TreeMap<Integer, Integer> sortedMap = new TreeMap<Integer, Integer>(new Comparator<Integer>() {
            public int compare(Integer o1, Integer o2) {
                // 如果有空值,直接返回0
                if (o1 == null || o2 == null)
                    return 0;
                return String.valueOf(o1).compareTo(String.valueOf(o2));
            }
        });
        sortedMap.putAll(oriMap);
        return sortedMap;
    }

    public Map<Double,Double> adaBoost(TreeMap<Integer, Integer> data) {
        Map<Double,Double> result=new HashMap<Double,Double>();
        int dataLenght=data.size();
        double[] weight = new double[dataLenght];
        //初始化权值数组
        for (int i = 0; i < dataLenght; i++) {
            weight[i]=1.0/dataLenght;
        }

        double grade1 = 0;
        double grade2 = 0;
        //double flag = 0;
        String f=null;
        double current=0;
        double ah=0;
        double low=data.firstKey();//选取最小的特征值
        double high=data.lastKey();//选取最大的特征值
        //迭代50次
        for(int it=0;it<50;it++){
            double min=1000;
            double flag=low;//用来标记比较优的特征的值
            while(flag<=high){
                int index = 0;// 用来索引权值数组
                grade1=0;
                grade2=0;
                for (Integer en : data.keySet()) {
                    //大于某一个特征值则为一时
                    if(GreatToOne(en, flag)!=data.get(en)){
                        grade1+=weight[index];
                    }
                    //小于某一个特征值则为一时
                    if(LessToOne(en, flag)!=data.get(en)){
                        grade2+=weight[index];
                    }   
                    index++;
                }
                //选取最优的特征值
                if (grade1 < min) {
                    min = grade1;
                    current = flag;
                    f="great";//用来标记采用的哪一个函数(GreatToOne or LessToOne)
                }
                if(grade2<min){
                    min=grade2;
                    current = flag;
                    f="less";
                }
                flag+=k;//将用来分类的特征值增加k
            }
            ah=0.5*Math.log((1-min)/min);
            double totle=0;
            int j=0;
            //
            for(Integer en:data.keySet()){
                if(f.equals("great")){
                    totle+=weight[j++]*Math.exp(-ah*data.get(en)*GreatToOne(en,current));
                }
                else{
                    totle+=weight[j++]*Math.exp(-ah*data.get(en)*LessToOne(en,current));
                }

            }
            j=0;
            for(Integer en:data.keySet()){
                if(f.equals("great")){
                    weight[j]=weight[j]*Math.exp(-ah*data.get(en)*GreatToOne(en,current))/totle;
                }
                else{
                    weight[j]=weight[j]*Math.exp(-ah*data.get(en)*LessToOne(en,current))/totle;
                }
                j++;
            }
        result.put(ah, current);
        list.add(f);
        //错误率为零,则退出
        if(calc(result,data)==0) break;
        }
    return result;
    }
    private int calc(Map<Double, Double> result, TreeMap<Integer, Integer> data) {
        // TODO Auto-generated method stub
        int count=0;

        for(Integer en:data.keySet()){
            double sum=0;int index=0;
            for(Double d:result.keySet()){
                if(list.get(index).equals("great")){
                    sum+=d*GreatToOne(en,result.get(d));
                }
                else{
                    sum+=d*LessToOne(en,result.get(d));
                }   
                index++;            
            }
            if(sum>0&&data.get(en)==-1) {
                count++;
            }
            if(sum<0&&data.get(en)==1){
                count++;
            }

        }
        if(count==0){
            return 0;
            }
        else{
            return 1;
        }
    }

    public int GreatToOne(int x,double flag){
        if(x>flag) {
            return 1;
        }else{
            return -1;
        }
    }
    public int LessToOne(int x,double flag){
        if(x<flag) {
            return 1;
        }else{
            return -1;
        }
    }
}


结果如下: 

这里写图片描述


好的,以下是使用Java实现AdaBoost算法代码,用于鸢尾花分类: ```java import java.util.ArrayList; import java.util.Collections; public class AdaBoost { // 训练数据集 private ArrayList<ArrayList<Double>> dataSet; // 类别标签 private ArrayList<Integer> labels; // 弱分类器数目 private int weakNum; // 训练好的弱分类器集合 private ArrayList<WeakClassifier> weakClassifiers; // AdaBoost构造函数 public AdaBoost(ArrayList<ArrayList<Double>> dataSet, ArrayList<Integer> labels, int weakNum) { this.dataSet = dataSet; this.labels = labels; this.weakNum = weakNum; this.weakClassifiers = new ArrayList<>(); } // 训练分类器 public void train() { int size = dataSet.size(); // 初始化权重向量 ArrayList<Double> weights = new ArrayList<>(); for (int i = 0; i < size; i++) { weights.add(1.0 / size); } // 训练 weakNum 个弱分类器 for (int i = 0; i < weakNum; i++) { // 训练单个弱分类器 WeakClassifier weakClassifier = new WeakClassifier(dataSet, labels, weights); weakClassifier.train(); // 计算错误率 double error = 0.0; for (int j = 0; j < size; j++) { if (weakClassifier.predict(dataSet.get(j)) != labels.get(j)) { error += weights.get(j); } } // 计算弱分类器权重 double alpha = 0.5 * Math.log((1 - error) / error); weakClassifier.setAlpha(alpha); // 更新权重向量 for (int j = 0; j < size; j++) { if (weakClassifier.predict(dataSet.get(j)) == labels.get(j)) { weights.set(j, weights.get(j) * Math.exp(-alpha)); } else { weights.set(j, weights.get(j) * Math.exp(alpha)); } } // 归一化权重向量 double sum = 0.0; for (int j = 0; j < size; j++) { sum += weights.get(j); } for (int j = 0; j < size; j++) { weights.set(j, weights.get(j) / sum); } // 将训练好的弱分类器加入集合 weakClassifiers.add(weakClassifier); } } // 预测分类结果 public int predict(ArrayList<Double> data) { double sum = 0.0; for (WeakClassifier wc : weakClassifiers) { sum += wc.predict(data) * wc.getAlpha(); } if (sum > 0) { return 1; } else { return -1; } } // 测试分类器 public void test(ArrayList<ArrayList<Double>> testData, ArrayList<Integer> testLabels) { int errorNum = 0; int size = testData.size(); for (int i = 0; i < size; i++) { if (predict(testData.get(i)) != testLabels.get(i)) { errorNum++; } } double accuracy = 1 - (double) errorNum / size; System.out.println("Accuracy: " + accuracy); } // 主函数 public static void main(String[] args) { // 读取数据集 ArrayList<ArrayList<Double>> dataSet = Util.loadDataSet("iris.data"); // 打乱数据集顺序 Collections.shuffle(dataSet); // 获取标签 ArrayList<Integer> labels = new ArrayList<>(); for (ArrayList<Double> data : dataSet) { if (data.get(data.size() - 1) == 1) { labels.add(1); } else { labels.add(-1); } } // 划分训练集和测试集 ArrayList<ArrayList<Double>> trainData = new ArrayList<>(); ArrayList<ArrayList<Double>> testData = new ArrayList<>(); ArrayList<Integer> trainLabels = new ArrayList<>(); ArrayList<Integer> testLabels = new ArrayList<>(); for (int i = 0; i < dataSet.size(); i++) { if (i % 5 == 0) { testData.add(dataSet.get(i)); testLabels.add(labels.get(i)); } else { trainData.add(dataSet.get(i)); trainLabels.add(labels.get(i)); } } // 训练 AdaBoost 分类器 AdaBoost adaBoost = new AdaBoost(trainData, trainLabels, 10); adaBoost.train(); // 测试分类器 adaBoost.test(testData, testLabels); } } ``` 需要注意的是,此代码中的 `WeakClassifier` 类是用于实现单个弱分类器的训练和预测的,需要自行实现。同时,数据集的加载和处理部分也需要根据实际情况进行修改。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值