决策树算法(matlab)

本文转载自: 点击打开链接


    决策树是一种特别简单的机器学习分类算法。决策树想法来源于人类的决策过程。举个最简单的例子,人类发现下雨的时候,往往会有刮东风,然后天色变暗。对应于决策树模型,预测天气模型中的刮东风和天色变暗就是我们收集的特征,是否下雨就是类别标签。构建的决策树如下图所示


      决策树模型构建过程为,在特征集合中无放回的依次递归抽选特征作为决策树的节点——当前节点信息增益或者增益率最大,当前节点的值作为当前节点分支出来的有向边(实际上主要选择的是这些边,这个由信息增益的计算公式就可以得到)。对于这个进行直观解释


      来说一个极端情况,如果有一个特征下,特征取不同值的时候,对应的类别标签都是纯的,决策者肯定会选择这个特征,作为鉴别未知数据的判别准则。由下面的计算信息增益的公式可以发现这时候对应的信息增益是最大的。

g(D,A)=H(D)-H(D|A)

      g(D,A):表示特征A对训练数据集D的信息增益

         H(D):表示数据集合D的经验熵

      H(D|A):表示特征A给定条件下数据集合D的条件熵。


      反之,当某个特征它的各个取值下对应的类别标签均匀分布的时候H(D|A)最大,又对于所有的特征H(D)是都一样的。因此,这时候的g(D,A)最小。

      总之一句话,我们要挑选的特征是:当前特征下各个取值包含的分类信息最明确。


下面我们来看一个MATLAB编写的决策树算法,帮助理解


[plain] view plaincopy
  1. clear;clc;  
  2.   
  3. % OutlookType=struct('Sunny',1,'Rainy',2,'Overcast',3);  
  4. % TemperatureType=struct('hot',1,'warm',2,'cool',3);  
  5. % HumidityType=struct('high',1,'norm',2);  
  6. % WindyType={'True',1,'False',0};  
  7. % PlayGolf={'Yes',1,'No',0};  
  8. % data=struct('Outlook',[],'Temperature',[],'Humidity',[],'Windy',[],'PlayGolf',[]);  
  9.   
  10. Outlook=[1,1,3,2,2,2,3,1,1,2,1,3,3,2]';  
  11. Temperature=[1,1,1,2,3,3,3,2,3,3,2,2,1,2]';  
  12. Humidity=[1,1,1,1,2,2,2,1,2,2,2,1,2,1]';  
  13. Windy=[0,1,0,0,0,1,1,0,0,0,1,1,0,1]';  
  14.   
  15. data=[Outlook Temperature Humidity Windy];  
  16. PlayGolf=[0,0,1,1,1,0,1,0,1,1,1,1,1,0]';  
  17. propertyName={'Outlook','Temperature','Humidity','Windy'};  
  18. delta=0.1;  
  19. decisionTreeModel=decisionTree(data,PlayGolf,propertyName,delta);  

[plain] view plaincopy
  1. function decisionTreeModel=decisionTree(data,label,propertyName,delta)  
  2.   
  3. global Node;  
  4.   
  5. Node=struct('fatherNodeName',[],'EdgeProperty',[],'NodeName',[]);  
  6. BuildTree('root','Stem',data,label,propertyName,delta);  
  7. Node(1)=[];  
  8. model.Node=Node;  
  9. decisionTreeModel=model;  

[plain] view plaincopy
  1. function BuildTree(fatherNodeName,edge,data,label,propertyName,delta)  
  2. %UNTITLED9 Summary of this function goes here  
  3. %   Detailed explanation goes here  
  4.   
  5. global Node;  
  6. sonNode=struct('fatherNodeName',[],'EdgeProperty',[],'NodeName',[]);  
  7. sonNode.fatherNodeName=fatherNodeName;  
  8. sonNode.EdgeProperty=edge;  
  9. if length(unique(label))==1  
  10.     sonNode.NodeName=label(1);  
  11.     Node=[Node sonNode];  
  12.     return;  
  13. end  
  14. if length(propertyName)<1  
  15.     labelSet=unique(label);  
  16.     labelNum=length(labelSet);  
  17.     for i=1:labelNum  
  18.         labelNum=length(find(label==labelSet(i)));  
  19.     end  
  20.     [~,labelIndex]=max(labelNum);  
  21.     sonNode.NodeName=labelSet(labelIndex);  
  22.     Node=[Node sonNode];  
  23.     return;  
  24. end  
  25. [sonIndex,BuildNode]=CalcuteNode(data,label,delta);  
  26. if BuildNode  
  27.     dataRowIndex=setdiff(1:length(propertyName),sonIndex);  
  28.     sonNode.NodeName=propertyName(sonIndex);  
  29.     Node=[Node sonNode];  
  30.     propertyName(sonIndex)=[];  
  31.     sonData=data(:,sonIndex);  
  32.     sonEdge=unique(sonData);  
  33.       
  34.     for i=1:length(sonEdge)  
  35.         edgeDataIndex=find(sonData==sonEdge(i));  
  36.         BuildTree(sonNode.NodeName,sonEdge(i),data(edgeDataIndex,dataRowIndex),label(edgeDataIndex,:),propertyName,delta);  
  37.     end  
  38. else  
  39.     labelSet=unique(label);  
  40.     labelNum=length(labelSet);  
  41.     for i=1:labelNum  
  42.         labelNum=length(find(label==labelSet(i)));  
  43.     end  
  44.     [~,labelIndex]=max(labelNum);  
  45.     sonNode.NodeName=labelSet(labelIndex);  
  46.     Node=[Node sonNode];  
  47.     return;  
  48. end  

[plain] view plaincopy
  1. function [NodeIndex,BuildNode]=CalcuteNode(data,label,delta)  
  2.   
  3. LargeEntropy=CEntropy(label);  
  4. [m,n]=size(data);  
  5. EntropyGain=LargeEntropy*ones(1,n);  
  6. BuildNode=true;  
  7. for i=1:n  
  8.     pData=data(:,i);  
  9.     itemList=unique(pData);  
  10.     for j=1:length(itemList)  
  11.         itemIndex=find(pData==itemList(j));  
  12.         EntropyGain(i)=EntropyGain(i)-length(itemIndex)/m*CEntropy(label(itemIndex));  
  13.     end  
  14.     % 此处运行则为增益率,注释掉则为增益  
  15.     % EntropyGain(i)=EntropyGain(i)/CEntropy(pData);   
  16. end  
  17. [maxGainEntropy,NodeIndex]=max(EntropyGain);  
  18. if maxGainEntropy<delta  
  19.     BuildNode=false;  
  20. end  

[plain] view plaincopy
  1. function result=CEntropy(propertyList)  
  2.   
  3. result=0;  
  4. totalLength=length(propertyList);  
  5. itemList=unique(propertyList);  
  6. pNum=length(itemList);  
  7. for i=1:pNum  
  8.     itemLength=length(find(propertyList==itemList(i)));  
  9.     pItem=itemLength/totalLength;  
  10.     result=result-pItem*log2(pItem);  
  11. end  

版权声明:本文为博主原创文章,未经博主允许不得转载

展开阅读全文

没有更多推荐了,返回首页