AdaBoost装袋提升算法

参开资料:http://blog.csdn.net/haidao2009/article/details/7514787
更多挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

介绍

在介绍AdaBoost算法之前,需要了解一个类似的算法,装袋算法(bagging),bagging是一种提高分类准确率的算法,通过给定组合投票的方式,获得最优解。比如你生病了,去n个医院看了n个医生,每个医生给你开了药方,最后的结果中,哪个药方的出现的次数多,那就说明这个药方就越有可能性是最由解,这个很好理解。而bagging算法就是这个思想。

算法原理

而AdaBoost算法的核心思想还是基于bagging算法,但是他又一点点的改进,上面的每个医生的投票结果都是一样的,说明地位平等,如果在这里加上一个权重,大城市的医生权重高点,小县城的医生权重低,这样通过最终计算权重和的方式,会更加的合理,这就是AdaBoost算法。AdaBoost算法是一种迭代算法,只有最终分类误差率小于阈值算法才能停止,针对同一训练集数据训练不同的分类器,我们称弱分类器,最后按照权重和的形式组合起来,构成一个组合分类器,就是一个强分类器了。算法的只要过程:

1、对D训练集数据训练处一个分类器Ci

2、通过分类器Ci对数据进行分类,计算此时误差率

3、把上步骤中的分错的数据的权重提高,分对的权重降低,以此凸显了分错的数据。为什么这么做呢,后面会做出解释。

完整的adaboost算法如下


最后的sign函数是符号函数,如果最后的值为正,则分为+1类,否则即使-1类。

我们举个例子代入上面的过程,这样能够更好的理解。

adaboost的实现过程:

  图中,“+”和“-”分别表示两种类别,在这个过程中,我们使用水平或者垂直的直线作为分类器,来进行分类。

  第一步:

  根据分类的正确率,得到一个新的样本分布D,一个子分类器h1

  其中划圈的样本表示被分错的。在右边的途中,比较大的“+”表示对该样本做了加权。

算法最开始给了一个均匀分布 D 。所以h1 里的每个点的值是0.1。ok,当划分后,有三个点划分错了,根据算法误差表达式得到 误差为分错了的三个点的值之和,所以ɛ1=(0.1+0.1+0.1)=0.3,而ɑ1 根据表达式的可以算出来为0.42. 然后就根据算法 把分错的点权值变大。如此迭代,最终完成adaboost算法。

  第二步:

  根据分类的正确率,得到一个新的样本分布D3,一个子分类器h2

  第三步:

  得到一个子分类器h3

  整合所有子分类器:

  因此可以得到整合的结果,从结果中看,及时简单的分类器,组合起来也能获得很好的分类效果,在例子中所有的。后面的代码实现时,举出的也是这个例子,可以做对比,这里有一点比较重要,就是点的权重经过大小变化之后,需要进行归一化,确保总和为1.0,这个容易遗忘。

算法的代码实现

输入测试数据,与上图的例子相对应(数据格式:x坐标 y坐标 已分类结果):

  1. 151
  2. 231
  3. 31-1
  4. 45-1
  5. 561
  6. 64-1
  7. 671
  8. 761
  9. 87-1
  10. 82-1

Point.java

  1. packageDataMining_AdaBoost;
  2. /**
  3. *坐标点类
  4. *
  5. *@authorlyq
  6. *
  7. */
  8. publicclassPoint{
  9. //坐标点x坐标
  10. privateintx;
  11. //坐标点y坐标
  12. privateinty;
  13. //坐标点的分类类别
  14. privateintclassType;
  15. //如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等
  16. privatedoubleprobably;
  17. publicPoint(intx,inty,intclassType){
  18. this.x=x;
  19. this.y=y;
  20. this.classType=classType;
  21. }
  22. publicPoint(Stringx,Stringy,StringclassType){
  23. this.x=Integer.parseInt(x);
  24. this.y=Integer.parseInt(y);
  25. this.classType=Integer.parseInt(classType);
  26. }
  27. publicintgetX(){
  28. returnx;
  29. }
  30. publicvoidsetX(intx){
  31. this.x=x;
  32. }
  33. publicintgetY(){
  34. returny;
  35. }
  36. publicvoidsetY(inty){
  37. this.y=y;
  38. }
  39. publicintgetClassType(){
  40. returnclassType;
  41. }
  42. publicvoidsetClassType(intclassType){
  43. this.classType=classType;
  44. }
  45. publicdoublegetProbably(){
  46. returnprobably;
  47. }
  48. publicvoidsetProbably(doubleprobably){
  49. this.probably=probably;
  50. }
  51. }
AdaBoost.java

  1. packageDataMining_AdaBoost;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.text.MessageFormat;
  7. importjava.util.ArrayList;
  8. importjava.util.HashMap;
  9. importjava.util.Map;
  10. /**
  11. *AdaBoost提升算法工具类
  12. *
  13. *@authorlyq
  14. *
  15. */
  16. publicclassAdaBoostTool{
  17. //分类的类别,程序默认为正类1和负类-1
  18. publicstaticfinalintCLASS_POSITIVE=1;
  19. publicstaticfinalintCLASS_NEGTIVE=-1;
  20. //事先假设的3个分类器(理论上应该重新对数据集进行训练得到)
  21. publicstaticfinalStringCLASSIFICATION1="X=2.5";
  22. publicstaticfinalStringCLASSIFICATION2="X=7.5";
  23. publicstaticfinalStringCLASSIFICATION3="Y=5.5";
  24. //分类器组
  25. publicstaticfinalString[]ClASSIFICATION=newString[]{
  26. CLASSIFICATION1,CLASSIFICATION2,CLASSIFICATION3};
  27. //分类权重组
  28. privatedouble[]CLASSIFICATION_WEIGHT;
  29. //测试数据文件地址
  30. privateStringfilePath;
  31. //误差率阈值
  32. privatedoubleerrorValue;
  33. //所有的数据点
  34. privateArrayList<Point>totalPoint;
  35. publicAdaBoostTool(StringfilePath,doubleerrorValue){
  36. this.filePath=filePath;
  37. this.errorValue=errorValue;
  38. readDataFile();
  39. }
  40. /**
  41. *从文件中读取数据
  42. */
  43. privatevoidreadDataFile(){
  44. Filefile=newFile(filePath);
  45. ArrayList<String[]>dataArray=newArrayList<String[]>();
  46. try{
  47. BufferedReaderin=newBufferedReader(newFileReader(file));
  48. Stringstr;
  49. String[]tempArray;
  50. while((str=in.readLine())!=null){
  51. tempArray=str.split("");
  52. dataArray.add(tempArray);
  53. }
  54. in.close();
  55. }catch(IOExceptione){
  56. e.getStackTrace();
  57. }
  58. Pointtemp;
  59. totalPoint=newArrayList<>();
  60. for(String[]array:dataArray){
  61. temp=newPoint(array[0],array[1],array[2]);
  62. temp.setProbably(1.0/dataArray.size());
  63. totalPoint.add(temp);
  64. }
  65. }
  66. /**
  67. *根据当前的误差值算出所得的权重
  68. *
  69. *@paramerrorValue
  70. *当前划分的坐标点误差率
  71. *@return
  72. */
  73. privatedoublecalculateWeight(doubleerrorValue){
  74. doublealpha=0;
  75. doubletemp=0;
  76. temp=(1-errorValue)/errorValue;
  77. alpha=0.5*Math.log(temp);
  78. returnalpha;
  79. }
  80. /**
  81. *计算当前划分的误差率
  82. *
  83. *@parampointMap
  84. *划分之后的点集
  85. *@paramweight
  86. *本次划分得到的分类器权重
  87. *@return
  88. */
  89. privatedoublecalculateErrorValue(
  90. HashMap<Integer,ArrayList<Point>>pointMap){
  91. doubleresultValue=0;
  92. doubletemp=0;
  93. doubleweight=0;
  94. inttempClassType;
  95. ArrayList<Point>pList;
  96. for(Map.Entryentry:pointMap.entrySet()){
  97. tempClassType=(int)entry.getKey();
  98. pList=(ArrayList<Point>)entry.getValue();
  99. for(Pointp:pList){
  100. temp=p.getProbably();
  101. //如果划分类型不相等,代表划错了
  102. if(tempClassType!=p.getClassType()){
  103. resultValue+=temp;
  104. }
  105. }
  106. }
  107. weight=calculateWeight(resultValue);
  108. for(Map.Entryentry:pointMap.entrySet()){
  109. tempClassType=(int)entry.getKey();
  110. pList=(ArrayList<Point>)entry.getValue();
  111. for(Pointp:pList){
  112. temp=p.getProbably();
  113. //如果划分类型不相等,代表划错了
  114. if(tempClassType!=p.getClassType()){
  115. //划错的点的权重比例变大
  116. temp*=Math.exp(weight);
  117. p.setProbably(temp);
  118. }else{
  119. //划对的点的权重比减小
  120. temp*=Math.exp(-weight);
  121. p.setProbably(temp);
  122. }
  123. }
  124. }
  125. //如果误差率没有小于阈值,继续处理
  126. dataNormalized();
  127. returnresultValue;
  128. }
  129. /**
  130. *概率做归一化处理
  131. */
  132. privatevoiddataNormalized(){
  133. doublesumProbably=0;
  134. doubletemp=0;
  135. for(Pointp:totalPoint){
  136. sumProbably+=p.getProbably();
  137. }
  138. //归一化处理
  139. for(Pointp:totalPoint){
  140. temp=p.getProbably();
  141. p.setProbably(temp/sumProbably);
  142. }
  143. }
  144. /**
  145. *用AdaBoost算法得到的组合分类器对数据进行分类
  146. *
  147. */
  148. publicvoidadaBoostClassify(){
  149. doublevalue=0;
  150. Pointp;
  151. calculateWeightArray();
  152. for(inti=0;i<ClASSIFICATION.length;i++){
  153. System.out.println(MessageFormat.format("分类器{0}权重为:{1}",(i+1),CLASSIFICATION_WEIGHT[i]));
  154. }
  155. for(intj=0;j<totalPoint.size();j++){
  156. p=totalPoint.get(j);
  157. value=0;
  158. for(inti=0;i<ClASSIFICATION.length;i++){
  159. value+=1.0*classifyData(ClASSIFICATION[i],p)
  160. *CLASSIFICATION_WEIGHT[i];
  161. }
  162. //进行符号判断
  163. if(value>0){
  164. System.out
  165. .println(MessageFormat.format(
  166. "点({0},{1})的组合分类结果为:1,该点的实际分类为{2}",p.getX(),p.getY(),
  167. p.getClassType()));
  168. }else{
  169. System.out.println(MessageFormat.format(
  170. "点({0},{1})的组合分类结果为:-1,该点的实际分类为{2}",p.getX(),p.getY(),
  171. p.getClassType()));
  172. }
  173. }
  174. }
  175. /**
  176. *计算分类器权重数组
  177. */
  178. privatevoidcalculateWeightArray(){
  179. inttempClassType=0;
  180. doubleerrorValue=0;
  181. ArrayList<Point>posPointList;
  182. ArrayList<Point>negPointList;
  183. HashMap<Integer,ArrayList<Point>>mapList;
  184. CLASSIFICATION_WEIGHT=newdouble[ClASSIFICATION.length];
  185. for(inti=0;i<CLASSIFICATION_WEIGHT.length;i++){
  186. mapList=newHashMap<>();
  187. posPointList=newArrayList<>();
  188. negPointList=newArrayList<>();
  189. for(Pointp:totalPoint){
  190. tempClassType=classifyData(ClASSIFICATION[i],p);
  191. if(tempClassType==CLASS_POSITIVE){
  192. posPointList.add(p);
  193. }else{
  194. negPointList.add(p);
  195. }
  196. }
  197. mapList.put(CLASS_POSITIVE,posPointList);
  198. mapList.put(CLASS_NEGTIVE,negPointList);
  199. if(i==0){
  200. //最开始的各个点的权重一样,所以传入0,使得e的0次方等于1
  201. errorValue=calculateErrorValue(mapList);
  202. }else{
  203. //每次把上次计算所得的权重代入,进行概率的扩大或缩小
  204. errorValue=calculateErrorValue(mapList);
  205. }
  206. //计算当前分类器的所得权重
  207. CLASSIFICATION_WEIGHT[i]=calculateWeight(errorValue);
  208. }
  209. }
  210. /**
  211. *用各个子分类器进行分类
  212. *
  213. *@paramclassification
  214. *分类器名称
  215. *@paramp
  216. *待划分坐标点
  217. *@return
  218. */
  219. privateintclassifyData(Stringclassification,Pointp){
  220. //分割线所属坐标轴
  221. Stringposition;
  222. //分割线的值
  223. doublevalue=0;
  224. doubleposProbably=0;
  225. doublenegProbably=0;
  226. //划分是否是大于一边的划分
  227. booleanisLarger=false;
  228. String[]array;
  229. ArrayList<Point>pList=newArrayList<>();
  230. array=classification.split("=");
  231. position=array[0];
  232. value=Double.parseDouble(array[1]);
  233. if(position.equals("X")){
  234. if(p.getX()>value){
  235. isLarger=true;
  236. }
  237. //将训练数据中所有属于这边的点加入
  238. for(Pointpoint:totalPoint){
  239. if(isLarger&&point.getX()>value){
  240. pList.add(point);
  241. }elseif(!isLarger&&point.getX()<value){
  242. pList.add(point);
  243. }
  244. }
  245. }elseif(position.equals("Y")){
  246. if(p.getY()>value){
  247. isLarger=true;
  248. }
  249. //将训练数据中所有属于这边的点加入
  250. for(Pointpoint:totalPoint){
  251. if(isLarger&&point.getY()>value){
  252. pList.add(point);
  253. }elseif(!isLarger&&point.getY()<value){
  254. pList.add(point);
  255. }
  256. }
  257. }
  258. for(Pointp2:pList){
  259. if(p2.getClassType()==CLASS_POSITIVE){
  260. posProbably++;
  261. }else{
  262. negProbably++;
  263. }
  264. }
  265. //分类按正负类数量进行划分
  266. if(posProbably>negProbably){
  267. returnCLASS_POSITIVE;
  268. }else{
  269. returnCLASS_NEGTIVE;
  270. }
  271. }
  272. }
调用类Client.java:

  1. /**
  2. *AdaBoost提升算法调用类
  3. *@authorlyq
  4. *
  5. */
  6. publicclassClient{
  7. publicstaticvoidmain(String[]agrs){
  8. StringfilePath="C:\\Users\\lyq\\Desktop\\icon\\input.txt";
  9. //误差率阈值
  10. doubleerrorValue=0.2;
  11. AdaBoostTooltool=newAdaBoostTool(filePath,errorValue);
  12. tool.adaBoostClassify();
  13. }
  14. }

输出结果:

  1. 分类器1权重为:0.424
  2. 分类器2权重为:0.65
  3. 分类器3权重为:0.923
  4. 点(1,5)的组合分类结果为:1,该点的实际分类为1
  5. 点(2,3)的组合分类结果为:1,该点的实际分类为1
  6. 点(3,1)的组合分类结果为:-1,该点的实际分类为-1
  7. 点(4,5)的组合分类结果为:-1,该点的实际分类为-1
  8. 点(5,6)的组合分类结果为:1,该点的实际分类为1
  9. 点(6,4)的组合分类结果为:-1,该点的实际分类为-1
  10. 点(6,7)的组合分类结果为:1,该点的实际分类为1
  11. 点(7,6)的组合分类结果为:1,该点的实际分类为1
  12. 点(8,7)的组合分类结果为:-1,该点的实际分类为-1
  13. 点(8,2)的组合分类结果为:-1,该点的实际分类为-1

我们可以看到,如果3个分类单独分类,都没有百分百分对,而尽管组合结果之后,全部分类正确。

我对AdaBoost算法的理解

到了算法的末尾,有必要解释一下每次分类自后需要把错的点的权重增大,正确的减少的理由了,加入上次分类之后,(1,5)已经分错了,如果这次又分错,由于上次的权重已经提升,所以误差率更大,则代入公式ln(1-误差率/误差率)所得的权重越小,也就是说,如果同个数据,你分类的次数越多,你的权重越小,所以这就造成整体好的分类器的权重会越大,内部就会同时有各种权重的分类器,形成了一种互补的结果,如果好的分类器结果分错 ,可以由若干弱一点的分类器进行弥补。

AdaBoost算法的应用

可以运用在诸如特征识别,二分类的一些应用上,与单个模型相比,组合的形式能显著的提高准确率。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值