机器学习入门算法及其java实现-adaboost算法

算法原理:

算法基本原理:

就一个训练样本集,求比较粗糙的分类规则(弱分类器)要比求精确的分类规则(强分类器)要容易的多,提升方法就是从弱学习算法出发,反复学习,得到一系列弱分类器(又称为基本分类器),然后组合这些弱分类器,构成一个强分类器。
AdaBoost采取加权多数表决的方法,具体地,加大分类误差率小的弱分类器的权值,使其在表决中其较大的作用,减小误差率大的弱分类器的权值,使其在表决中起较小的作用。

算法基本流程:

输入:训练数据集T={(x1,y1),(x2,y2),...,(xN,yN)},其中xiχRyiy={1,+1};弱学习分类算法Gm(x),m=1,2,...,M
输出:最终分类器G(x)
- 初始化训练数据的权值分布:

D1=(w11,...,w1i,...,w1N),w1,i=1N,i=1,2,...,N

- 对 m=1,2,...,M
(a)、使用具有权值分布Dm的训练数据集学习,得到基本分类器
Gm:χ{1,+1}

(b)、计算Gm(x)在训练数据集上的分类误差率
em=P(Gm(xiyi))=i=1NwmiI(Gm(xi)yi)

(c)、计算Gm(x)的系数
α=12log1emem

这里的对数是自然对数。
(d)、更新训练数据集的权值分布
Dm+1=(wm+1,1,...,wm+1,i,...,wm+1,N)
wm+1,i=wmizmexp(αmyiGm(xi))

这里,Zm是规范化因子
Zm=i=1Nwmiexp(αmyiGm(xi))

它使Dm+1成为一个概率分布。
(3)构建基本分类器的线性组合
f(x)=m+1MαmGm(x)

得到最终分类器
G(x)=sign(f(x))=sign(m=1MαmGm(x))

package adaBoost;

import java.io.IOException;

public class adaBoostmain {
    public static void main(String arg[]) throws IOException{
        InputStringData data=new InputStringData();
        String[][] mydata=data.loadData("adaboost测试数据.txt");
        for(int i=0;i<mydata.length;i++){
            for(int j=0;j<mydata[0].length;j++){
                System.out.print(mydata[i][j]+" ");
            }
            System.out.println(" ");
        }
        String[][] feature=new String[mydata[0].length][mydata.length-1];
        String[] classfied=new String[mydata[0].length];
        for(int i=0;i<feature.length;i++){
            for(int j=0;j<feature[0].length;j++){
                feature[i][j]=mydata[j][i];
                }
        }
        //读取样本属性,并存入feature
        for(int i=0;i<classfied.length;i++){
            classfied[i]=mydata[1][i];
        }
        //读取样本类别,并存入classfied
        Base base=new Base();
        String[][]Result=base.result(feature);
        for(int i=0;i<Result.length;i++){
            for(int j=0;j<Result[0].length;j++){
            }
        }
        //计算基分类器分类结果并存入数组Result

        AdaBoost myboost=new AdaBoost(Result,classfied);
        double[] alfa=myboost.alfa();
        for(int i=0;i<alfa.length;i++){
            System.out.print("第"+i+"个基分类器权重:"+alfa[i]+" ");
            System.out.println(" ");
        }
        //打印基分类器权重
    }
}

package adaBoost;

public class AdaBoost {
    String[] classfied;
    double[]weights;
    Arith ari=new Arith();
    double[]alfa;
    String[][]feature;

    public AdaBoost(String[][]a,String[]b){
        classfied=b;
        for(int i=0;i<classfied.length;i++){
        }
        feature=a;
        weights=new double[b.length];
        alfa=new double[a.length];
        for(int i=0;i<weights.length;i++){
            weights[i]=ari.div(1,b.length);
        }
    }
    //初始化权重系数


    public double[]alfa(){
        for(int i=0;i<alfa.length;i++){
            int index=0;
            double error=error(weights,feature[0]);
            for(int j=1;j<feature.length;j++){
                if(error(weights,feature[j])<error){
                    error=error(weights,feature[j]);
                    index=j;
                }
            }
            if(error!=0&&error!=1){
                alfa[index]=ari.mul(ari.div(1, 2),Math.log(ari.div(1-error,error)));
            }
            weight(alfa[index],feature[index]);
        }
        return alfa;
    }
    //计算基分类器权重

    public double error(double[]a,String[]b){
        double error=0;
        for(int i=0;i<b.length;i++){
            if(!b[i].equals(classfied[i])){
                error=error+a[i];
            }
        }
        return error;
    }
    //计算误差

    public void weight(double a,String[]b){
        double zm=0;
        for(int i=0;i<b.length;i++){
            zm=zm+weights[i]*Math.exp(ari.mul(-a,ari.mul(Double.parseDouble(b[i]),Double.parseDouble(classfied[i]))));
        }
        for(int i=0;i<weights.length;i++){
            weights[i]=ari.mul(ari.div(weights[i],zm),Math.exp(ari.mul(-a,ari.mul(Double.parseDouble(b[i]),Double.parseDouble(classfied[i])))));
        }
    }
    //计算样本权重
}


package adaBoost;

import java.math.BigDecimal;

public class Arith{
private static final int DEF_DIV_SCALE=10;

          public double add(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.add(b2).doubleValue();
              }
          public double sub(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.subtract(b2).doubleValue();
              }
          public double mul(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.multiply(b2).doubleValue();
              }
          public double div(double v1,double v2){
              return div(v1,v2,DEF_DIV_SCALE);
              }
          public double div(double v1,double v2,int scale){
                  if(scale<0){
                      throw new IllegalArgumentException(
                              "The scale must be a positive integer or zero");
                      }
                  BigDecimal b1=new BigDecimal(Double.toString(v1));
                  BigDecimal b2=new BigDecimal(Double.toString(v2));
                  return b1.divide(b2,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
                  }
          public double mul(double v1,double v2,int scale){
              if(scale<0){
                  throw new IllegalArgumentException(
                          "The scale must be a positive integer or zero");
                  }
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              if(v1!=0&&v2!=0){
                  BigDecimal b3=new BigDecimal(Double.toString(1));
                  BigDecimal b4=new BigDecimal(b3.divide(b2,scale,BigDecimal.ROUND_HALF_UP).doubleValue());
                  return b1.divide(b4,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
              }
              else{
                  return 0;
              }
              }
              public double round(double v,int scale){
                  if(scale<0){
                      throw new IllegalArgumentException(
                              "The scale must be a positive integer or zero");
                      }
                  BigDecimal b=new BigDecimal(Double.toString(v));
                  BigDecimal one=new BigDecimal("1");
                  return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
                  }
              }
//因为浮点数不利于精确计算,这个包用来进行一些数值计算(此包来自于博客,经过了本人一定修改)


package adaBoost;

public class Base {
    private int m=3;
    public String[][]result(String[][]a){
        String[][]result=new String[m][a.length];
        double[][]t=new double[a.length][a[0].length];
        for(int i=0;i<t.length;i++){
            for(int j=0;j<t[0].length;j++){
                try{
                    t[i][j]=Double.parseDouble(a[i][j]);
                            }
                catch(Exception e){

                }
            }
        }
        result[0]=G1(t);
        result[1]=G2(t);
        result[2]=G3(t);
        return result;
    }
    //收集三个基分类器结果存入result
    public String[]G1(double[][]a){
        String[]G1=new String[a.length];
        for(int i=0;i<a.length;i++){
            if(a[i][0]<2.5){
                G1[i]="1";
            }
            else{
                G1[i]="-1";
            }
        }
        return G1;
    }
    //以2.5为界,大于2.5的为-1,小于2.5的为1

    public String[]G2(double[][]b){
        String[]G2=new String[b.length];
        for(int i=0;i<b.length;i++){
            if(b[i][0]>5.5){
                G2[i]="1";
            }
            else{
                G2[i]="-1";
            }
        }
        return G2;
    }
    //以5.5为界,大于5.5的为1,小于5.5的为-1

    public String[]G3(double[][]c){
        String[]G3=new String[c.length];
        for(int i=0;i<c.length;i++){
            if(c[i][0]<8.5){
                G3[i]="1";
            }
            else{
                G3[i]="-1";
            }
        }
        return G3;
    }
    //以8.5为界,小于8.5的为1,大约8.5的为-1
}

package adaBoost;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Scanner;

public class InputStringData {
           int countRow=0,countCol=0,temp=0;
    public  String[][] loadData(String trainfile)throws IOException{
           ArrayList<String>features = new ArrayList<String>();
           File file = new File("C:\\Users\\CJH\\Desktop\\R程序运行",trainfile);
           Scanner input1 = new Scanner(file);
           while(input1.hasNext()){
               String line = input1.nextLine();
               Scanner input2 = new Scanner(line);
               countRow++;input2.close();
           }
           Scanner input11 = new Scanner(file);
           int[] length=new int[countRow];
           int i=0;
           while(input11.hasNext()){
               String line = input11.nextLine();
               Scanner input2 = new Scanner(line);
               temp=0;
               while(input2.hasNext()){
               features.add(input2.next());
               temp++;
               }
               if(countCol<temp){
                   countCol=temp;
               }
               length[i]=temp;
               i++;
               input2.close();
           }
           input11.close();
           String [][]x = new String[countRow][countCol];
           int index=0;
           for(int i1=0;i1<countRow;i1++){
               for(int j=0;j<countCol;j++){
                   if(length[i1]<=j){
                       x[i1][j]="null";
                   }
                   else{
                       x[i1][j]=features.get(index);
                       index++;
                   }
               }
           }
   return x;
}
}
//输入原始数据

实验结果及实例分析

0 1 2 3 4 5 6 7 8 9
1 1 1 -1 -1 -1 1 1 1 -1
第1个基分类器权重:0.423648930186459
第2个基分类器权重:0.7520386982992482
第3个基分类器权重:0.6496414921651305

权重变化过程:
添加第0个基分类器:

样本1 样本2 样本3 样本4 样本5 样本6 样本7 样本8 样本9 样本10
分类器1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
分类器3 0.071 0.071 0.071 0.071 0.071 0.167 0.167 0.167 0.167
分类器2 0.045 0.045 0.045 0.167 0.167 0.167 0.106 0.106 0.106

设置初始权重0.1,经过计算后,第一个分类器的误差最小,6、7、8分错,误差率为0.3。经过计算α1=0.4236
经过第二轮权重计算,增加了样本6、7、8的重要性,因此其它样本降低为0.071,6、7、8权重上升为0.167
第三个基分类器将3、4、5分错了,误差为0.2143,系数计算0.6496。因此增加3、4、5的权重。经过计算α3=0.6496
第二个基分类器误差为0.1820,经过计算α2=0.7514

发布了7 篇原创文章 · 获赞 3 · 访问量 5945
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览