AdaBoost 算法代码：AdaBoost 的 Java 实现

1 importjava.io.BufferedReader;2 importjava.io.FileInputStream;3 importjava.io.IOException;4 importjava.io.InputStreamReader;5 importjava.util.ArrayList;6

/**
 * A one-level decision tree ("decision stump") chosen in one AdaBoost round.
 *
 * <p>Fields are deliberately non-private: {@code Utils} creates Stump instances and reads
 * {@code labelList} and {@code factor} directly.
 */
class Stump {
    /** Index of the feature column this stump splits on. */
    public int dim;
    /** Split threshold compared against feature column {@link #dim}. */
    public double thresh;
    /** Split direction; {@code Utils.getClassify} treats "lt" as "value <= thresh" and anything else as "value >= thresh". */
    public String condition;
    /** Presumably the weighted training error of this stump (set by the stump search). */
    public double error;
    /** Predicted label (1 or -1) for each training sample, in data-set order. */
    public ArrayList<Integer> labelList;
    /** Presumably this stump's vote weight (alpha); used when reweighting samples and accumulating scores. */
    double factor;

    /**
     * Renders every field on its own line, for printing the trained ensemble.
     *
     * @return a multi-line, human-readable description of this stump
     */
    @Override
    public String toString() {
        return "dim is " + dim
                + "\nthresh is " + thresh
                + "\ncondition is " + condition
                + "\nerror is " + error
                + "\nfactor is " + factor
                + "\nlabel is " + labelList;
    }
}

// Utils: static helpers for loading the data file and training AdaBoost decision stumps.
// NOTE(review): this class is a lossy web-page extraction and does not compile as shown.
// Keywords are fused to the following token (e.g. "classUtils", "returnerror"), the
// original line numbers are embedded in the code, generic type arguments were stripped,
// and several lines were cut off, apparently where a '<' appeared. Missing original
// lines: 34-40, 77-82, 112-114, 129-133, 139-147, 160-164, 172-174, 177-182, 186-190,
// 196-201, plus part of 205-206. Restore it from the original source before use.
20 classUtils{21 //Load the data set

// loadDataSet: opens `filename` as UTF-8 and, for every line, creates a new row list and
// splits the line on a single space.
// NOTE(review): the column-parsing loop, row insertion and return (original lines 34-40)
// are missing, so the exact row layout cannot be confirmed here.
22 public static ArrayList> loadDataSet(String filename) throwsIOException{23 ArrayList> dataSet=new ArrayList>();24 FileInputStream fis=newFileInputStream(filename);25 InputStreamReader isr=new InputStreamReader(fis,"UTF-8");26 BufferedReader br=newBufferedReader(isr);27 String line="";28

29 while((line=br.readLine())!=null){30 ArrayList data=new ArrayList();31 String[] s=line.split(" ");32

33 for(int i=0;i

41 //Load the class labels

// loadLabelSet: reads `filename` as UTF-8 and returns, for each line, the last
// single-space-separated token parsed with Integer.parseInt (NumberFormatException if
// that token is not an integer).
// NOTE(review): the FileInputStream/InputStreamReader/BufferedReader chain is never
// closed; wrap it in try-with-resources.
42 public static ArrayList loadLabelSet(String filename) throwsNumberFormatException, IOException{43 ArrayList labelSet=new ArrayList();44

45 FileInputStream fis=newFileInputStream(filename);46 InputStreamReader isr=new InputStreamReader(fis,"UTF-8");47 BufferedReader br=newBufferedReader(isr);48 String line="";49

50 while((line=br.readLine())!=null){51 String[] s=line.split(" ");52 labelSet.add(Integer.parseInt(s[s.length-1]));53 }54 returnlabelSet;55 }56 //For testing

// showDataSet: prints each row of dataSet on its own line.
57 public static void showDataSet(ArrayList>dataSet){58 for(ArrayListdata:dataSet){59 System.out.println(data);60 }61 }62 //Get the maximum value, used to compute the step size

// getMax: largest value in column `index` over all rows.
// NOTE(review): starts from the sentinel -9999.0, so it returns -9999.0 for an empty
// data set or when every value is <= -9999.0; Double.NEGATIVE_INFINITY would be safer.
63 public static double getMax(ArrayList> dataSet,intindex){64 double max=-9999.0;65 for(ArrayListdata:dataSet){66 if(data.get(index)>max){67 max=data.get(index);68 }69 }70 returnmax;71 }72 //Get the minimum value, used to compute the step size

// getMin: smallest value in column `index`, starting from the sentinel 9999.0.
// NOTE(review): the comparison, update and return (original lines 76-82) are truncated.
73 public static double getMin(ArrayList> dataSet,intindex){74 double min=9999.0;75 for(ArrayListdata:dataSet){76 if(data.get(index)

83 //Get the predicted labels produced by splitting the data set with a decision tree whose single leaf uses this feature, with thresh and condition as its value

// getClassify: with condition "lt", a row gets label 1 when row[feature] <= thresh and -1
// otherwise; with any other condition, 1 when row[feature] >= thresh and -1 otherwise.
// Rows exactly equal to thresh get label 1 under both conditions.
// NOTE(review): compareTo("lt")==0 is an equality test; "lt".equals(condition) reads
// more clearly and tolerates a null condition.
84 public static ArrayList getClassify(ArrayList> dataSet,int feature,doublethresh,String condition){85 ArrayList labelList=new ArrayList();86 if(condition.compareTo("lt")==0){87 for(ArrayListdata:dataSet){88 if(data.get(feature)<=thresh){89 labelList.add(1);90 }else{91 labelList.add(-1);92 }93 }94 }else{95 for(ArrayListdata:dataSet){96 if(data.get(feature)>=thresh){97 labelList.add(1);98 }else{99 labelList.add(-1);100 }101 }102 }103 returnlabelList;104 }105 //Weighted error between the predicted labels and the true labels

// getError: weighted error of predicted labels `fake` against true labels `real`,
// iterating over real.size() samples.
// NOTE(review): the loop body (original lines 112-114) is missing; presumably it sums
// weights.get(i) for every mispredicted sample -- confirm against the original.
106 public static double getError(ArrayList fake,ArrayList real,ArrayListweights){107 double error=0;108

109 int n=real.size();110

111 for(int i=0;i

115 }116 }117

118 returnerror;119 }120 //Build a single-node decision tree and store its basic information in a Stump object.

// buildStump: prints the round number `n`, then (in the visible code) tries each
// condition for candidate (feature i, threshold j) pairs via getClassify/getError,
// tracking the smallest error starting from 999.0, and returns the chosen Stump.
// NOTE(review): dataSet.get(0) throws IndexOutOfBoundsException on an empty data set.
// The declarations of `conditions` and `j` (original lines 129-133) and the
// best-stump bookkeeping (139-147) are missing, so how the Stump fields are filled
// cannot be confirmed here.
121 public static Stump buildStump(ArrayList> dataSet,ArrayList labelSet,ArrayList weights,intn){122 int featureNum=dataSet.get(0).size();123

124 int rowNum=dataSet.size();125 Stump stump=newStump();126 double minError=999.0;127 System.out.println("第"+n+"次迭代");128 for(int i=0;i

134 for(String condition:conditions){135 ArrayList labelList=getClassify(dataSet,i,j,condition);136

137 double error=Utils.getError(labelList,labelSet,weights);138 if(error

148 }149 }150

151 }152

153 returnstump;154 }155

// getInitWeights: initial sample weight is 1.0/n (Infinity when n == 0).
// NOTE(review): the fill loop and return (original lines 159-164) are truncated;
// presumably n copies of 1.0/n.
156 public static ArrayList getInitWeights(intn){157 double weight=1.0/n;158 ArrayList weights=new ArrayList();159 for(int i=0;i

// updateWeights: builds a new weight list from the old weights, labelList and
// stump.factor, with a normalizer Z and e = Math.E.
// NOTE(review): both loops (original lines 171-182) are truncated; presumably the
// usual AdaBoost update w_i * e^(-factor * y_i * h_i) / Z -- confirm.
165 public static ArrayList updateWeights(Stump stump,ArrayList labelList,ArrayListweights){166 double Z=0;167 ArrayList newWeights=new ArrayList();168 int row=labelList.size();169 double e=Math.E;170 double factor=stump.factor;171 for(int i=0;i

175

176 for(int i=0;i

// InitAccWeightError: creates the per-sample accumulator later passed to accWeightError.
// NOTE(review): the initialization loop (original lines 185-190) is truncated;
// presumably n zeros.
183 public static ArrayList InitAccWeightError(intn){184 ArrayList accError=new ArrayList();185 for(int i=0;i

// accWeightError: combines `accerror` with stump.labelList and stump.factor into a new list.
// NOTE(review): the loop (original lines 195-201) is truncated; presumably
// newAccError[i] = accerror[i] + factor * labelList[i] -- confirm.
191 public static ArrayList accWeightError(ArrayListaccerror,Stump stump){192 ArrayList t=stump.labelList;193 double factor=stump.factor;194 ArrayList newAccError=new ArrayList();195 for(int i=0;i

// calErrorRate: fraction of samples counted as wrong, judged from the accumulated score
// against the true label.
// NOTE(review): the loop header and condition (original lines 205-206) are truncated.
// If the condition is accError.get(i)>0, a score of exactly 0 counts as wrong for label 1
// and right for label -1. The local list `a` is never used. Divides by accError.size(),
// so an empty input yields NaN.
202 public static double calErrorRate(ArrayList accError,ArrayListlabelList){203 ArrayList a=new ArrayList();204 int wrong=0;205 for(int i=0;i0){207 if(labelList.get(i)==-1){208 wrong++;209 }210 }else if(labelList.get(i)==1){211 wrong++;212 }213 }214 double error=wrong*1.0/accError.size();215 returnerror;216 }217

// showStumpList: prints each stump followed by a line containing a single space.
218 public static void showStumpList(ArrayListG){219 for(Stump s:G){220 System.out.println(s);221 System.out.println(" ");222 }223 }224 }225

226

227 public classAdaboost {228

229 /**

230 *@paramargs231 *@throwsIOException232 */

233

234 public static ArrayList AdaBoostTrain(ArrayList> dataSet,ArrayListlabelList){235 int row=labelList.size();236 ArrayList weights=Utils.getInitWeights(row);237 ArrayList G=new ArrayList();238 ArrayList accError=Utils.InitAccWeightError(row);239 int n=1;240 while(true){241 Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树

242 G.add(stump);243 weights=Utils.updateWeights(stump,labelList,weights);//更新权值

244 accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了

245 double error=Utils.calErrorRate(accError,labelList);246 if(error<0.001){247 break;248 }249 n++;250 }251 returnG;252 }253

254 public static void main(String[] args) throwsIOException {255 //TODO Auto-generated method stub

256 String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";257 ArrayList> dataSet=Utils.loadDataSet(file);258 ArrayList labelSet=Utils.loadLabelSet(file);259 ArrayList G=AdaBoostTrain(dataSet,labelSet);260 Utils.showStumpList(G);261 System.out.println("finished");262 }263

264 }

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值