决策分类树算法之ID3,C4.5算法系列

一、引言

在最开始的时候,我本来准备学习的是C4.5算法,后来发现C4.5算法的核心还是ID3算法,所以又辗转回到学习ID3算法了,因为C4.5是他的一个改进。至于是什么改进,在后面的描述中我会提到。

二、ID3算法

ID3算法是一种分类决策树算法。他通过一系列的规则,将数据最后分类成决策树的形式。分类的根据是用到了熵这个概念。熵在物理这门学科中就已经出现过,表示是一个物质的稳定度,在这里就是分类的纯度的一个概念。公式为:


在ID3算法中,是采用Gain信息增益来作为一个分类的判定标准的。他的定义为:


每次选择属性中信息增益最大作为划分属性,在这里本人实现了一个java版本的ID3算法,为了模拟数据的可操作性,就把数据写到一个input.txt文件中,作为数据源,格式如下:

  1. DayOutLookTemperatureHumidityWindPlayTennis
  2. 1SunnyHotHighWeakNo
  3. 2SunnyHotHighStrongNo
  4. 3OvercastHotHighWeakYes
  5. 4RainyMildHighWeakYes
  6. 5RainyCoolNormalWeakYes
  7. 6RainyCoolNormalStrongNo
  8. 7OvercastCoolNormalStrongYes
  9. 8SunnyMildHighWeakNo
  10. 9SunnyCoolNormalWeakYes
  11. 10RainyMildNormalWeakYes
  12. 11SunnyMildNormalStrongYes
  13. 12OvercastMildHighStrongYes
  14. 13OvercastHotNormalWeakYes
  15. 14RainyMildHighStrongNo
PalyTennis属性为结构属性,是作为类标识用的,中间的OutLool,Temperature,Humidity,Wind才是划分属性,通过将源数据与执行程序分类,这样可以模拟巨大的数据量了。下面是ID3的主程序类,本人将ID3的算法进行了包装,对外只开放了一个构建决策树的方法,在构造函数时候,只需传入一个数据路径文件即可:

  1. packageDataMing_ID3;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.util.ArrayList;
  7. importjava.util.HashMap;
  8. importjava.util.Iterator;
  9. importjava.util.Map;
  10. importjava.util.Map.Entry;
  11. importjava.util.Set;
  12. /**
  13. *ID3算法实现类
  14. *
  15. *@authorlyq
  16. *
  17. */
  18. publicclassID3Tool{
  19. //类标号的值类型
  20. privatefinalStringYES="Yes";
  21. privatefinalStringNO="No";
  22. //所有属性的类型总数,在这里就是data源数据的列数
  23. privateintattrNum;
  24. privateStringfilePath;
  25. //初始源数据,用一个二维字符数组存放模仿表格数据
  26. privateString[][]data;
  27. //数据的属性行的名字
  28. privateString[]attrNames;
  29. //每个属性的值所有类型
  30. privateHashMap<String,ArrayList<String>>attrValue;
  31. publicID3Tool(StringfilePath){
  32. this.filePath=filePath;
  33. attrValue=newHashMap<>();
  34. }
  35. /**
  36. *从文件中读取数据
  37. */
  38. privatevoidreadDataFile(){
  39. Filefile=newFile(filePath);
  40. ArrayList<String[]>dataArray=newArrayList<String[]>();
  41. try{
  42. BufferedReaderin=newBufferedReader(newFileReader(file));
  43. Stringstr;
  44. String[]tempArray;
  45. while((str=in.readLine())!=null){
  46. tempArray=str.split("");
  47. dataArray.add(tempArray);
  48. }
  49. in.close();
  50. }catch(IOExceptione){
  51. e.getStackTrace();
  52. }
  53. data=newString[dataArray.size()][];
  54. dataArray.toArray(data);
  55. attrNum=data[0].length;
  56. attrNames=data[0];
  57. /*
  58. *for(inti=0;i<data.length;i++){for(intj=0;j<data[0].length;j++){
  59. *System.out.print(""+data[i][j]);}
  60. *
  61. *System.out.print("\n");}
  62. */
  63. }
  64. /**
  65. *首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
  66. */
  67. privatevoidinitAttrValue(){
  68. ArrayList<String>tempValues;
  69. //按照列的方式,从左往右找
  70. for(intj=1;j<attrNum;j++){
  71. //从一列中的上往下开始寻找值
  72. tempValues=newArrayList<>();
  73. for(inti=1;i<data.length;i++){
  74. if(!tempValues.contains(data[i][j])){
  75. //如果这个属性的值没有添加过,则添加
  76. tempValues.add(data[i][j]);
  77. }
  78. }
  79. //一列属性的值已经遍历完毕,复制到map属性表中
  80. attrValue.put(data[0][j],tempValues);
  81. }
  82. /*
  83. *for(Map.Entryentry:attrValue.entrySet()){
  84. *System.out.println("key:value"+entry.getKey()+":"+
  85. *entry.getValue());}
  86. */
  87. }
  88. /**
  89. *计算数据按照不同方式划分的熵
  90. *
  91. *@paramremainData
  92. *剩余的数据
  93. *@paramattrName
  94. *待划分的属性,在算信息增益的时候会使用到
  95. *@paramattrValue
  96. *划分的子属性值
  97. *@paramisParent
  98. *是否分子属性划分还是原来不变的划分
  99. */
  100. privatedoublecomputeEntropy(String[][]remainData,StringattrName,
  101. Stringvalue,booleanisParent){
  102. //实例总数
  103. inttotal=0;
  104. //正实例数
  105. intposNum=0;
  106. //负实例数
  107. intnegNum=0;
  108. //还是按列从左往右遍历属性
  109. for(intj=1;j<attrNames.length;j++){
  110. //找到了指定的属性
  111. if(attrName.equals(attrNames[j])){
  112. for(inti=1;i<remainData.length;i++){
  113. //如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤
  114. if(isParent
  115. ||(!isParent&&remainData[i][j].equals(value))){
  116. if(remainData[i][attrNames.length-1].equals(YES)){
  117. //判断此行数据是否为正实例
  118. posNum++;
  119. }else{
  120. negNum++;
  121. }
  122. }
  123. }
  124. }
  125. }
  126. total=posNum+negNum;
  127. doubleposProbobly=(double)posNum/total;
  128. doublenegProbobly=(double)negNum/total;
  129. if(posProbobly==1||posProbobly==0){
  130. //如果数据全为同种类型,则熵为0,否则带入下面的公式会报错
  131. return0;
  132. }
  133. doubleentropyValue=-posProbobly*Math.log(posProbobly)
  134. /Math.log(2.0)-negProbobly*Math.log(negProbobly)
  135. /Math.log(2.0);
  136. //返回计算所得熵
  137. returnentropyValue;
  138. }
  139. /**
  140. *为某个属性计算信息增益
  141. *
  142. *@paramremainData
  143. *剩余的数据
  144. *@paramvalue
  145. *待划分的属性名称
  146. *@return
  147. */
  148. privatedoublecomputeGain(String[][]remainData,Stringvalue){
  149. doublegainValue=0;
  150. //源熵的大小将会与属性划分后进行比较
  151. doubleentropyOri=0;
  152. //子划分熵和
  153. doublechildEntropySum=0;
  154. //属性子类型的个数
  155. intchildValueNum=0;
  156. //属性值的种数
  157. ArrayList<String>attrTypes=attrValue.get(value);
  158. //子属性对应的权重比
  159. HashMap<String,Integer>ratioValues=newHashMap<>();
  160. for(inti=0;i<attrTypes.size();i++){
  161. //首先都统一计数为0
  162. ratioValues.put(attrTypes.get(i),0);
  163. }
  164. //还是按照一列,从左往右遍历
  165. for(intj=1;j<attrNames.length;j++){
  166. //判断是否到了划分的属性列
  167. if(value.equals(attrNames[j])){
  168. for(inti=1;i<=remainData.length-1;i++){
  169. childValueNum=ratioValues.get(remainData[i][j]);
  170. //增加个数并且重新存入
  171. childValueNum++;
  172. ratioValues.put(remainData[i][j],childValueNum);
  173. }
  174. }
  175. }
  176. //计算原熵的大小
  177. entropyOri=computeEntropy(remainData,value,null,true);
  178. for(inti=0;i<attrTypes.size();i++){
  179. doubleratio=(double)ratioValues.get(attrTypes.get(i))
  180. /(remainData.length-1);
  181. childEntropySum+=ratio
  182. *computeEntropy(remainData,value,attrTypes.get(i),false);
  183. //System.out.println("ratio:value:"+ratio+""+
  184. //computeEntropy(remainData,value,
  185. //attrTypes.get(i),false));
  186. }
  187. //二者熵相减就是信息增益
  188. gainValue=entropyOri-childEntropySum;
  189. returngainValue;
  190. }
  191. /**
  192. *计算信息增益比
  193. *
  194. *@paramremainData
  195. *剩余数据
  196. *@paramvalue
  197. *待划分属性
  198. *@return
  199. */
  200. privatedoublecomputeGainRatio(String[][]remainData,Stringvalue){
  201. doublegain=0;
  202. doublespiltInfo=0;
  203. intchildValueNum=0;
  204. //属性值的种数
  205. ArrayList<String>attrTypes=attrValue.get(value);
  206. //子属性对应的权重比
  207. HashMap<String,Integer>ratioValues=newHashMap<>();
  208. for(inti=0;i<attrTypes.size();i++){
  209. //首先都统一计数为0
  210. ratioValues.put(attrTypes.get(i),0);
  211. }
  212. //还是按照一列,从左往右遍历
  213. for(intj=1;j<attrNames.length;j++){
  214. //判断是否到了划分的属性列
  215. if(value.equals(attrNames[j])){
  216. for(inti=1;i<=remainData.length-1;i++){
  217. childValueNum=ratioValues.get(remainData[i][j]);
  218. //增加个数并且重新存入
  219. childValueNum++;
  220. ratioValues.put(remainData[i][j],childValueNum);
  221. }
  222. }
  223. }
  224. //计算信息增益
  225. gain=computeGain(remainData,value);
  226. //计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
  227. for(inti=0;i<attrTypes.size();i++){
  228. doubleratio=(double)ratioValues.get(attrTypes.get(i))
  229. /(remainData.length-1);
  230. spiltInfo+=-ratio*Math.log(ratio)/Math.log(2.0);
  231. }
  232. //计算机信息增益率
  233. returngain/spiltInfo;
  234. }
  235. /**
  236. *利用源数据构造决策树
  237. */
  238. privatevoidbuildDecisionTree(AttrNodenode,StringparentAttrValue,
  239. String[][]remainData,ArrayList<String>remainAttr,booleanisID3){
  240. node.setParentAttrValue(parentAttrValue);
  241. StringattrName="";
  242. doublegainValue=0;
  243. doubletempValue=0;
  244. //如果只有1个属性则直接返回
  245. if(remainAttr.size()==1){
  246. System.out.println("attrnull");
  247. return;
  248. }
  249. //选择剩余属性中信息增益最大的作为下一个分类的属性
  250. for(inti=0;i<remainAttr.size();i++){
  251. //判断是否用ID3算法还是C4.5算法
  252. if(isID3){
  253. //ID3算法采用的是按照信息增益的值来比
  254. tempValue=computeGain(remainData,remainAttr.get(i));
  255. }else{
  256. //C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
  257. tempValue=computeGainRatio(remainData,remainAttr.get(i));
  258. }
  259. if(tempValue>gainValue){
  260. gainValue=tempValue;
  261. attrName=remainAttr.get(i);
  262. }
  263. }
  264. node.setAttrName(attrName);
  265. ArrayList<String>valueTypes=attrValue.get(attrName);
  266. remainAttr.remove(attrName);
  267. AttrNode[]childNode=newAttrNode[valueTypes.size()];
  268. String[][]rData;
  269. for(inti=0;i<valueTypes.size();i++){
  270. //移除非此值类型的数据
  271. rData=removeData(remainData,attrName,valueTypes.get(i));
  272. childNode[i]=newAttrNode();
  273. booleansameClass=true;
  274. ArrayList<String>indexArray=newArrayList<>();
  275. for(intk=1;k<rData.length;k++){
  276. indexArray.add(rData[k][0]);
  277. //判断是否为同一类的
  278. if(!rData[k][attrNames.length-1]
  279. .equals(rData[1][attrNames.length-1])){
  280. //只要有1个不相等,就不是同类型的
  281. sameClass=false;
  282. break;
  283. }
  284. }
  285. if(!sameClass){
  286. //创建新的对象属性,对象的同个引用会出错
  287. ArrayList<String>rAttr=newArrayList<>();
  288. for(Stringstr:remainAttr){
  289. rAttr.add(str);
  290. }
  291. buildDecisionTree(childNode[i],valueTypes.get(i),rData,
  292. rAttr,isID3);
  293. }else{
  294. //如果是同种类型,则直接为数据节点
  295. childNode[i].setParentAttrValue(valueTypes.get(i));
  296. childNode[i].setChildDataIndex(indexArray);
  297. }
  298. }
  299. node.setChildAttrNode(childNode);
  300. }
  301. /**
  302. *属性划分完毕,进行数据的移除
  303. *
  304. *@paramsrcData
  305. *源数据
  306. *@paramattrName
  307. *划分的属性名称
  308. *@paramvalueType
  309. *属性的值类型
  310. */
  311. privateString[][]removeData(String[][]srcData,StringattrName,
  312. StringvalueType){
  313. String[][]desDataArray;
  314. ArrayList<String[]>desData=newArrayList<>();
  315. //待删除数据
  316. ArrayList<String[]>selectData=newArrayList<>();
  317. selectData.add(attrNames);
  318. //数组数据转化到列表中,方便移除
  319. for(inti=0;i<srcData.length;i++){
  320. desData.add(srcData[i]);
  321. }
  322. //还是从左往右一列列的查找
  323. for(intj=1;j<attrNames.length;j++){
  324. if(attrNames[j].equals(attrName)){
  325. for(inti=1;i<desData.size();i++){
  326. if(desData.get(i)[j].equals(valueType)){
  327. //如果匹配这个数据,则移除其他的数据
  328. selectData.add(desData.get(i));
  329. }
  330. }
  331. }
  332. }
  333. desDataArray=newString[selectData.size()][];
  334. selectData.toArray(desDataArray);
  335. returndesDataArray;
  336. }
  337. /**
  338. *开始构建决策树
  339. *
  340. *@paramisID3
  341. *是否采用ID3算法构架决策树
  342. */
  343. publicvoidstartBuildingTree(booleanisID3){
  344. readDataFile();
  345. initAttrValue();
  346. ArrayList<String>remainAttr=newArrayList<>();
  347. //添加属性,除了最后一个类标号属性
  348. for(inti=1;i<attrNames.length-1;i++){
  349. remainAttr.add(attrNames[i]);
  350. }
  351. AttrNoderootNode=newAttrNode();
  352. buildDecisionTree(rootNode,"",data,remainAttr,isID3);
  353. showDecisionTree(rootNode,1);
  354. }
  355. /**
  356. *显示决策树
  357. *
  358. *@paramnode
  359. *待显示的节点
  360. *@paramblankNum
  361. *行空格符,用于显示树型结构
  362. */
  363. privatevoidshowDecisionTree(AttrNodenode,intblankNum){
  364. System.out.println();
  365. for(inti=0;i<blankNum;i++){
  366. System.out.print("\t");
  367. }
  368. System.out.print("--");
  369. //显示分类的属性值
  370. if(node.getParentAttrValue()!=null
  371. &&node.getParentAttrValue().length()>0){
  372. System.out.print(node.getParentAttrValue());
  373. }else{
  374. System.out.print("--");
  375. }
  376. System.out.print("--");
  377. if(node.getChildDataIndex()!=null
  378. &&node.getChildDataIndex().size()>0){
  379. Stringi=node.getChildDataIndex().get(0);
  380. System.out.print("类别:"
  381. +data[Integer.parseInt(i)][attrNames.length-1]);
  382. System.out.print("[");
  383. for(Stringindex:node.getChildDataIndex()){
  384. System.out.print(index+",");
  385. }
  386. System.out.print("]");
  387. }else{
  388. //递归显示子节点
  389. System.out.print("【"+node.getAttrName()+"】");
  390. for(AttrNodechildNode:node.getChildAttrNode()){
  391. showDecisionTree(childNode,2*blankNum);
  392. }
  393. }
  394. }
  395. }
他的场景调用实现的方式为:

  1. /**
  2. *ID3决策树分类算法测试场景类
  3. *@authorlyq
  4. *
  5. */
  6. publicclassClient{
  7. publicstaticvoidmain(String[]args){
  8. StringfilePath="C:\\Users\\lyq\\Desktop\\icon\\input.txt";
  9. ID3Tooltool=newID3Tool(filePath);
  10. tool.startBuildingTree(true);
  11. }
  12. }
最终的结果为:

  1. ------【OutLook】
  2. --Sunny--【Humidity】
  3. --High--类别:No[1,2,8,]
  4. --Normal--类别:Yes[9,11,]
  5. --Overcast--类别:Yes[3,7,12,13,]
  6. --Rainy--【Wind】
  7. --Weak--类别:Yes[4,5,10,]
  8. --Strong--类别:No[6,14,]

请从左往右观察这棵决策树,【】里面的是分类属性,---XXX----,XXX为属性的值,在叶子节点处为类标记。

对应的分类结果图:


这里的构造决策树和显示决策树采用的DFS的方法,所以可能会比较难懂,希望读者能细细体会,可以调试一下代码,一步步的跟踪会更加容易理解的。

三、C4.5算法

如果你已经理解了上面ID3算法的实现,那么理解C4.5也很容易了,C4.5与ID3在核心的算法是一样的,但是有一点所采用的办法是不同的,C4.5采用了信息增益率作为划分的根据,克服了ID3算法中采用信息增益划分导致属性选择偏向取值多的属性。信息增益率的公式为:


分母的位置是分裂因子,他的计算公式为:


和熵的计算公式比较像,具体的信息增益率的算法也在上面的代码中了,请关注着2个方法:

  1. //选择剩余属性中信息增益最大的作为下一个分类的属性
  2. for(inti=0;i<remainAttr.size();i++){
  3. //判断是否用ID3算法还是C4.5算法
  4. if(isID3){
  5. //ID3算法采用的是按照信息增益的值来比
  6. tempValue=computeGain(remainData,remainAttr.get(i));
  7. }else{
  8. //C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
  9. tempValue=computeGainRatio(remainData,remainAttr.get(i));
  10. }
  11. if(tempValue>gainValue){
  12. gainValue=tempValue;
  13. attrName=remainAttr.get(i);
  14. }
  15. }
在补充一下C4.5在其他方面对ID3的补充和改进:

1、在构造决策树的过程中能对树进行剪枝。

2、能对连续性的值进行离散化的操作。

四、编码时遇到的一些问题

为了实现ID3算法,从理解阅读他的原理就已经用掉了比较多的时间,然后再尝试阅读别人写的C++版本的代码,又是看了几天,好不容易实现了2个算法,最后在构造树的过程中遇到了最大了麻烦,因为用到了递归构造树,对于其中节点的设计就显得至关重要了,也许我自己目前的设计也不是最优秀的。下面盘点一下我的程序的遇到的一些问题和存在的潜在的问题:

1、在构建决策树的时候,出现了remainAttr值缺少的情况,就是递归的时候remainAttr的属性划分移除掉之后,对于上次的递归操作的属性时受到影响了,后来发现是因为我remainAttr采用的是ArrayList,他是一个引用对象,通过引用传入的方式,对象用的还是同一个,所以果断重新建了一个ArrayList对象,问题就OK了。

  1. //创建新的对象属性,对象的同个引用会出错
  2. ArrayList<String>rAttr=newArrayList<>();
  3. for(Stringstr:remainAttr){
  4. rAttr.add(str);
  5. }
  6. buildDecisionTree(childNode[i],valueTypes.get(i),rData,
  7. rAttr,isID3);
2、第二个问题是当程序划分到最后一个属性时,如果出现了数据的类标识并不是同一个类的时候,我的处理操作时直接不处理,直接返回,会造成节点没有数据属性,也没有数据索引。

  1. privatevoidbuildDecisionTree(AttrNodenode,StringparentAttrValue,
  2. String[][]remainData,ArrayList<String>remainAttr,booleanisID3){
  3. node.setParentAttrValue(parentAttrValue);
  4. StringattrName="";
  5. doublegainValue=0;
  6. doubletempValue=0;
  7. //如果只有1个属性则直接返回
  8. if(remainAttr.size()==1){
  9. System.out.println("attrnull");
  10. return;
  11. }
  12. .....
在这种情况下的处理不是很恰当个人觉得是这样。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值