机器学习入门算法及其java实现-adaboost算法

算法原理:

算法基本原理:

就一个训练样本集,求比较粗糙的分类规则(弱分类器)要比求精确的分类规则(强分类器)要容易的多,提升方法就是从弱学习算法出发,反复学习,得到一系列弱分类器(又称为基本分类器),然后组合这些弱分类器,构成一个强分类器。
AdaBoost采取加权多数表决的方法,具体地,加大分类误差率小的弱分类器的权值,使其在表决中其较大的作用,减小误差率大的弱分类器的权值,使其在表决中起较小的作用。

算法基本流程:

输入:训练数据集 T={(x1,y1),(x2,y2),...,(xN,yN)} ,其中 xiχR yiy={1,+1} ;弱学习分类算法 Gm(x) , m=1,2,...,M
输出:最终分类器 G(x)
- 初始化训练数据的权值分布:

D1=(w11,...,w1i,...,w1N),w1,i=1N,i=1,2,...,N

- 对 m=1,2,...,M
(a)、使用具有权值分布 Dm 的训练数据集学习,得到基本分类器
Gm:χ{1,+1}

(b)、计算 Gm(x) 在训练数据集上的分类误差率
em=P(Gm(xiyi))=i=1NwmiI(Gm(xi)yi)

(c)、计算 Gm(x) 的系数
α=12log1emem

这里的对数是自然对数。
(d)、更新训练数据集的权值分布
Dm+1=(wm+1,1,...,wm+1,i,...,wm+1,N)
wm+1,i=wmizmexp(αmyiGm(xi))

这里, Zm 是规范化因子
Zm=i=1Nwmiexp(αmyiGm(xi))

它使 Dm+1 成为一个概率分布。
(3)构建基本分类器的线性组合
f(x)=m+1MαmGm(x)

得到最终分类器
G(x)=sign(f(x))=sign(m=1MαmGm(x))

package adaBoost;

import java.io.IOException;

public class adaBoostmain {
    public static void main(String arg[]) throws IOException{
        InputStringData data=new InputStringData();
        String[][] mydata=data.loadData("adaboost测试数据.txt");
        for(int i=0;i<mydata.length;i++){
            for(int j=0;j<mydata[0].length;j++){
                System.out.print(mydata[i][j]+" ");
            }
            System.out.println(" ");
        }
        String[][] feature=new String[mydata[0].length][mydata.length-1];
        String[] classfied=new String[mydata[0].length];
        for(int i=0;i<feature.length;i++){
            for(int j=0;j<feature[0].length;j++){
                feature[i][j]=mydata[j][i];
                }
        }
        //读取样本属性,并存入feature
        for(int i=0;i<classfied.length;i++){
            classfied[i]=mydata[1][i];
        }
        //读取样本类别,并存入classfied
        Base base=new Base();
        String[][]Result=base.result(feature);
        for(int i=0;i<Result.length;i++){
            for(int j=0;j<Result[0].length;j++){
            }
        }
        //计算基分类器分类结果并存入数组Result

        AdaBoost myboost=new AdaBoost(Result,classfied);
        double[] alfa=myboost.alfa();
        for(int i=0;i<alfa.length;i++){
            System.out.print("第"+i+"个基分类器权重:"+alfa[i]+" ");
            System.out.println(" ");
        }
        //打印基分类器权重
    }
}

package adaBoost;

public class AdaBoost {
    String[] classfied;
    double[]weights;
    Arith ari=new Arith();
    double[]alfa;
    String[][]feature;

    public AdaBoost(String[][]a,String[]b){
        classfied=b;
        for(int i=0;i<classfied.length;i++){
        }
        feature=a;
        weights=new double[b.length];
        alfa=new double[a.length];
        for(int i=0;i<weights.length;i++){
            weights[i]=ari.div(1,b.length);
        }
    }
    //初始化权重系数


    public double[]alfa(){
        for(int i=0;i<alfa.length;i++){
            int index=0;
            double error=error(weights,feature[0]);
            for(int j=1;j<feature.length;j++){
                if(error(weights,feature[j])<error){
                    error=error(weights,feature[j]);
                    index=j;
                }
            }
            if(error!=0&&error!=1){
                alfa[index]=ari.mul(ari.div(1, 2),Math.log(ari.div(1-error,error)));
            }
            weight(alfa[index],feature[index]);
        }
        return alfa;
    }
    //计算基分类器权重

    public double error(double[]a,String[]b){
        double error=0;
        for(int i=0;i<b.length;i++){
            if(!b[i].equals(classfied[i])){
                error=error+a[i];
            }
        }
        return error;
    }
    //计算误差

    public void weight(double a,String[]b){
        double zm=0;
        for(int i=0;i<b.length;i++){
            zm=zm+weights[i]*Math.exp(ari.mul(-a,ari.mul(Double.parseDouble(b[i]),Double.parseDouble(classfied[i]))));
        }
        for(int i=0;i<weights.length;i++){
            weights[i]=ari.mul(ari.div(weights[i],zm),Math.exp(ari.mul(-a,ari.mul(Double.parseDouble(b[i]),Double.parseDouble(classfied[i])))));
        }
    }
    //计算样本权重
}


package adaBoost;

import java.math.BigDecimal;

public class Arith{
private static final int DEF_DIV_SCALE=10;

          public double add(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.add(b2).doubleValue();
              }
          public double sub(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.subtract(b2).doubleValue();
              }
          public double mul(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.multiply(b2).doubleValue();
              }
          public double div(double v1,double v2){
              return div(v1,v2,DEF_DIV_SCALE);
              }
          public double div(double v1,double v2,int scale){
                  if(scale<0){
                      throw new IllegalArgumentException(
                              "The scale must be a positive integer or zero");
                      }
                  BigDecimal b1=new BigDecimal(Double.toString(v1));
                  BigDecimal b2=new BigDecimal(Double.toString(v2));
                  return b1.divide(b2,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
                  }
          public double mul(double v1,double v2,int scale){
              if(scale<0){
                  throw new IllegalArgumentException(
                          "The scale must be a positive integer or zero");
                  }
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              if(v1!=0&&v2!=0){
                  BigDecimal b3=new BigDecimal(Double.toString(1));
                  BigDecimal b4=new BigDecimal(b3.divide(b2,scale,BigDecimal.ROUND_HALF_UP).doubleValue());
                  return b1.divide(b4,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
              }
              else{
                  return 0;
              }
              }
              public double round(double v,int scale){
                  if(scale<0){
                      throw new IllegalArgumentException(
                              "The scale must be a positive integer or zero");
                      }
                  BigDecimal b=new BigDecimal(Double.toString(v));
                  BigDecimal one=new BigDecimal("1");
                  return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
                  }
              }
//因为浮点数不利于精确计算,这个包用来进行一些数值计算(此包来自于博客,经过了本人一定修改)


package adaBoost;

public class Base {
    private int m=3;
    public String[][]result(String[][]a){
        String[][]result=new String[m][a.length];
        double[][]t=new double[a.length][a[0].length];
        for(int i=0;i<t.length;i++){
            for(int j=0;j<t[0].length;j++){
                try{
                    t[i][j]=Double.parseDouble(a[i][j]);
                            }
                catch(Exception e){

                }
            }
        }
        result[0]=G1(t);
        result[1]=G2(t);
        result[2]=G3(t);
        return result;
    }
    //收集三个基分类器结果存入result
    public String[]G1(double[][]a){
        String[]G1=new String[a.length];
        for(int i=0;i<a.length;i++){
            if(a[i][0]<2.5){
                G1[i]="1";
            }
            else{
                G1[i]="-1";
            }
        }
        return G1;
    }
    //以2.5为界,大于2.5的为-1,小于2.5的为1

    public String[]G2(double[][]b){
        String[]G2=new String[b.length];
        for(int i=0;i<b.length;i++){
            if(b[i][0]>5.5){
                G2[i]="1";
            }
            else{
                G2[i]="-1";
            }
        }
        return G2;
    }
    //以5.5为界,大于5.5的为1,小于5.5的为-1

    public String[]G3(double[][]c){
        String[]G3=new String[c.length];
        for(int i=0;i<c.length;i++){
            if(c[i][0]<8.5){
                G3[i]="1";
            }
            else{
                G3[i]="-1";
            }
        }
        return G3;
    }
    //以8.5为界,小于8.5的为1,大约8.5的为-1
}

package adaBoost;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Scanner;

public class InputStringData {
           int countRow=0,countCol=0,temp=0;
    public  String[][] loadData(String trainfile)throws IOException{
           ArrayList<String>features = new ArrayList<String>();
           File file = new File("C:\\Users\\CJH\\Desktop\\R程序运行",trainfile);
           Scanner input1 = new Scanner(file);
           while(input1.hasNext()){
               String line = input1.nextLine();
               Scanner input2 = new Scanner(line);
               countRow++;input2.close();
           }
           Scanner input11 = new Scanner(file);
           int[] length=new int[countRow];
           int i=0;
           while(input11.hasNext()){
               String line = input11.nextLine();
               Scanner input2 = new Scanner(line);
               temp=0;
               while(input2.hasNext()){
               features.add(input2.next());
               temp++;
               }
               if(countCol<temp){
                   countCol=temp;
               }
               length[i]=temp;
               i++;
               input2.close();
           }
           input11.close();
           String [][]x = new String[countRow][countCol];
           int index=0;
           for(int i1=0;i1<countRow;i1++){
               for(int j=0;j<countCol;j++){
                   if(length[i1]<=j){
                       x[i1][j]="null";
                   }
                   else{
                       x[i1][j]=features.get(index);
                       index++;
                   }
               }
           }
   return x;
}
}
//输入原始数据

实验结果及实例分析

0 1 2 3 4 5 6 7 8 9
1 1 1 -1 -1 -1 1 1 1 -1
第1个基分类器权重:0.423648930186459
第2个基分类器权重:0.7520386982992482
第3个基分类器权重:0.6496414921651305

权重变化过程:
添加第0个基分类器:

样本1样本2样本3样本4样本5样本6样本7样本8样本9样本10
分类器10.10.10.10.10.10.10.10.10.1
分类器30.0710.0710.0710.0710.0710.1670.1670.1670.167
分类器20.0450.0450.0450.1670.1670.1670.1060.1060.106

设置初始权重0.1,经过计算后,第一个分类器的误差最小,6、7、8分错,误差率为0.3。经过计算 α1=0.4236
经过第二轮权重计算,增加了样本6、7、8的重要性,因此其它样本降低为0.071,6、7、8权重上升为0.167
第三个基分类器将3、4、5分错了,误差为0.2143,系数计算0.6496。因此增加3、4、5的权重。经过计算 α3=0.6496
第二个基分类器误差为0.1820,经过计算 α2=0.7514

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值