决策树和随机森林学习笔记

1 衡量一个随机变量的不确定性

1.1 信息熵

H ( X ) H(X) H(X)是用于描述 X X X携带信息量的,信息量越大(值变化越多),越不确定,越不容易被预测。

X X X是一个取有限个值的离散随机变量,其概率分布为 P ( X = x i ) = p i , i = 1 , 2 , 3 , . . . , n P(X=x_i)=p_i,i=1,2,3,...,n P(X=xi)=pi,i=1,2,3,...,n,则随机变量 X X X的熵为:
H ( X ) = − ∑ i = 1 n p i l o g p i H(X)=-\sum_{i=1}^{n}{p_ilogp_i} H(X)=i=1npilogpi

1.2 Gini系数

在分类问题中,假设有 K K K个类,样本点属于第 k k k个类的概率为 p k p_k pk,该概率分布的基尼系数为:
G i n i ( p ) = ∑ k = 1 K p k ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini(p)=\sum_{k=1}^{K}{p_k(1-p_k)}=1-\sum_{k=1}^{K}{p_k^2} Gini(p)=k=1Kpk(1pk)=1k=1Kpk2
C k C_k Ck D D D中属于第 k k k类的样本子集,则基尼系数为:
G i n i ( D ) = 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ D ∣ ) 2 Gini(D)=1-\sum_{k=1}^{K}{(\frac{|C_k|}{|D|})^2} Gini(D)=1k=1K(DCk)2
设条件 A A A将样本 D D D切分为 D 1 D_1 D1 D 1 D_1 D1两个数据子集,则在条件 A A A下的样本 D D D的基尼指数为:
G i n i ( D , A ) = ∣ D 1 ∣ D G i n i ( D 1 ) + ∣ D 2 ∣ D G i n i ( D 2 ) Gini(D,A)=\frac{|D_1|}{D}Gini(D_1)+\frac{|D_2|}{D}Gini(D_2) Gini(D,A)=DD1Gini(D1)+DD2Gini(D2)
G i n i ( D ) Gini(D) Gini(D)表示集合 D D D的不确定性, G i n i ( D , A ) Gini(D,A) Gini(D,A)表示经 A = a A=a A=a分割后集合 D D D的不确定性。

2 决策树的构建(ID3)

关键点:

  • 上一章提到熵越大,即信息量越大,值变化越多,越不确定,越不容易被预测,那么想要成功预测,需要尽量将熵值降低,降得越快越好
  • 构建决策树的关键时选取节点,就应该选取熵越小的决策属性作为节点,这样让熵下降得最快,就能够得到最优得决策树。

有如下的天气数据,根据这些数据构建一个决策树模型,预测要不要打球

outlooktemperaturehumiditywindyplay
sunnyhothighFALSEno
sunnyhothighTRUEno
overcasthothighFALSEyes
rainymildhighFALSEyes
rainycoolnormalFALSEyes
rainycoolnormalTRUEno
overcastcoolnormalTRUEyes
sunnymildhighFALSEno
sunnycoolnormalFALSEyes
rainymildnormalFALSEyes
sunnymildnormalTRUEyes
overcastmildhighTRUEyes
overcasthotnormalFALSEyes
rainymildhighTRUEno

2.1 计算系统固有熵

在没有任何特征划分的情况下计算数据的固有的熵值,其中打球的概率为 9 14 \frac{9}{14} 149,不打球的概率为 5 14 \frac{5}{14} 145,其熵为
H ( X ) = − ∑ i = 1 n p i l o g p i = − 9 14 l o g 9 14 − 5 14 l o g 5 14 ≈ 0.652 H(X)=-\sum_{i=1}^{n}{p_ilogp_i}=-\frac{9}{14}log\frac{9}{14} - \frac{5}{14}log\frac{5}{14}\approx 0.652 H(X)=i=1npilogpi=149log149145log1450.652

2.2 计算分支熵

假设选取outlook作为根节点,其取值有三个,分别时:sunny、rainy和overcast。其中为sunny时打球的概率时 2 5 \frac{2}{5} 52,不打球的概率为 3 5 \frac{3}{5} 53,其熵为 − 2 5 l o g 2 5 − 3 5 l o g 3 5 ≈ 0.673 -\frac{2}{5}log\frac{2}{5}-\frac{3}{5}log\frac{3}{5}\approx0.673 52log5253log530.673;当值为rainy时,都会打球,其熵为0;当值为overcast时,打球的概率时 3 5 \frac{3}{5} 53,不打球的概率为 2 5 \frac{2}{5} 52,其熵为 − 3 5 l o g 3 5 − 2 5 l o g 2 5 ≈ 0.673 -\frac{3}{5}log\frac{3}{5}-\frac{2}{5}log\frac{2}{5}\approx0.673 53log5352log520.673

2.3 计算总熵

那么如果outlook为根节点,sunny的概率为 5 14 \frac{5}{14} 145,rainy的概率为 4 14 \frac{4}{14} 144,overcast的概率为 5 14 \frac{5}{14} 145,则其总的熵为:
H = 5 14 × 0.673 + 4 14 × 0 + 5 14 × 0.673 ≈ 0.481 H = \frac{5}{14}\times0.673 + \frac{4}{14}\times0 + \frac{5}{14}\times0.673 \approx0.481 H=145×0.673+144×0+145×0.6730.481

2.4 计算信息增益

那么信息增益为: 0.652 − 0.481 = 0.171 0.652-0.481=0.171 0.6520.481=0.171
同理,以其他的决策属性为根节点计算其总熵,并计算信息增益,最后发现outlook的信息增益最大,因此该决策树选择outlook为根节点。

2.5 ID3的缺陷

  • ID3没有考虑到连续的情况;
  • ID3采用信息增益大的特征优先建立节点;
  • ID3未作缺失值的考虑;
  • ID3没有考虑过拟合的情况。

事实上,在相同条件下,取值比较多的特征比取值少的特征信息增益大,ID3算法可能选择无助决策的特征。极端的情况如下:

idoutlooktemperaturehumiditywindyplay
1sunnyhothighFALSEno
2sunnyhothighTRUEno
3overcasthothighFALSEyes
4rainymildhighFALSEyes
5rainycoolnormalFALSEyes
6rainycoolnormalTRUEno
7overcastcoolnormalTRUEyes
8sunnymildhighFALSEno
9sunnycoolnormalFALSEyes
10rainymildnormalFALSEyes
11sunnymildnormalTRUEyes
12overcastmildhighTRUEyes
13overcasthotnormalFALSEyes
14rainymildhighTRUEno

id作为一个决策特征,但id与去不去打球半毛钱关系都没有,而按照信息增益大的特征优先建立节点的原则,id会被当作根节点,无疑时非常荒唐的。

4 C4.5

4.1 对连续值的处理

将连续值离散化,如有 m m m个样本的连续特征 A A A m m m个,从大到小排列为 a 1 , a 2 , a 3 , . . . , a m a_1, a_2, a_3, ..., a_m a1,a2,a3,...,am,取相邻两样本值的中位数,一共取得 m − 1 m-1 m1个划分点。对于这 m − 1 m-1 m1个点,分别计算以该点作为二元分类点时的信息增益。选择信息增益最大的点作为该连续特征的二元离散分类点。

4.2 规避选择不相关特征

引入信息增益率
信 息 增 益 率 = 信 息 增 益 特 征 熵 信息增益率=\frac{信息增益}{特征熵} =
如上表,虽然信息增益非常大,但其信息熵也很大,因此信息增益率其实并不大,还没有到选择它作为根节点的程度。

4.3 缺失值处理

缺失值处理的问题,主要需要解决的是两个问题

  • 一是在特征值缺失的情况下进行划分特征的选择?(即如何计算特征的信息增益率);
  • 二是选定该划分特征,对于缺失该特征值的样本如何处理?(即到底把这个样本划分到哪个结点里)。

对于具有缺失值特征,用没有缺失的样本子集所占比重来折算;

对于第二个子问题,可以将缺失特征的样本同时划分入所有的子节点,不过将该样本的权重按各个子节点样本的数量比例来分配。比如缺失特征A的样本 a a a之前权重为 1 1 1,特征 A A A 3 3 3个特征值 A 1 A_1 A1 A 2 A_2 A2 A 3 A_3 A3 3 3 3个特征值对应的无缺失 A A A特征的样本个数为 2 2 2 3 3 3 4 4 4,则 a a a同时划分入 A 1 A_1 A1 A 2 A_2 A2 A 3 A_3 A3。对应权重调节为 2 9 \frac{2}{9} 92 3 9 \frac{3}{9} 93 4 9 \frac{4}{9} 94

4.4 过拟合处理

采用剪枝的方法,分为预剪枝和后剪枝两种。预剪枝是指在生成决策树的时候就决定是否剪枝。后剪枝指先生成决策树,再通过交叉验证来剪枝。

预剪枝策略:

  • 节点内数据样本低于某一阈值;
  • 所有节点特征都已分裂;
  • 节点划分前准确率比划分后准确率高。

后剪枝策略:

  • 采用的悲观剪枝方法,用递归的方式从低往上针对每一个非叶子节点,评估用一个最佳叶子节点去代替这棵子树是否有益。如果剪枝后与剪枝前相比其错误率是保持或者下降,则这棵子树就可以被替换掉。C4.5 通过训练数据集上的错误分类数量来估算未知样本上的错误率。后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝决策树。但同时其训练时间会长的多。

4.5 C4.5的不足

  • C4.5生成的是多叉树,即一个父节点可以有多个节点,而大多数时候在计算机中二叉树模型会比多叉树运算效率高;
  • C4.5只能用于分类;
  • C4.5由于使用了熵模型,里面有大量的耗时的对数运算,如果是连续值还有大量的排序运算,较为耗时。

5 CART算法

CART树的生成就是递归地构建二叉决策树的过程,对回归树用平方误差最小化准则,对分类树用基尼系数最小化准则,进行特征选择,生成二叉树。

5.1 CART分类树

ID3采用了信息增益思路选择特征,信息增益大的优先选择;C4.5采用信息增益率选择特征,改进了采用信息增益选择特征带来的错误。但无论时ID3或是C4.5,他们都采用了信息熵的模型,均会涉及大量的对数运算。

CART算法采用Gini系数代替了信息增益率,其中Gini系数代表了模型的不纯度,基尼系数越小,则不纯度越低,特征越好,正好与信息增益(率)是相反的。

另外CART分类树算法每次仅仅对某个特征的值进行二分,而不是多分,这样CART算法建立起来的是二叉树,而不是多叉树,进一步简化了计算。

有条件A将样本D切分为 D 1 D_1 D1 D 2 D_2 D2两个数据子集,其Gini增益为:
△ G i n i ( A ) = G i n i ( A ) − G i n i ( D , A ) = { 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ d ∣ ) 2 } − { ∣ D 1 ∣ D G i n i ( D 1 ) + ∣ D 2 ∣ D G i n i ( D 2 ) } \triangle Gini(A)=Gini(A)-Gini(D,A)=\{1-\sum_{k=1}^{K}{(\frac{|C_k|}{|d|})^2}\}-\{\frac{|D_1|}{D}Gini(D_1)+\frac{|D_2|}{D}Gini(D_2)\} Gini(A)=Gini(A)Gini(D,A)={1k=1K(dCk)2}{DD1Gini(D1)+DD2Gini(D2)}
根据训练数据集,从根节点开始,递归对每一个节点进行如下操作,构建二叉决策树:

  • 1)设训练数据集为 D D D,计算现有特征对该数据集的基尼指数。此时对每一个特征 A A A,对其可能取得每个值 a a a,根据样本点对 A = a A=a A=a得测试为“是”或“否”将 D D D切分为 D 1 D_1 D1 D 2 D_2 D2两部分,即根据 A ≥ a A\geq a Aa A < a A<a A<a将样本分为两部分,并计算 G i n i ( D , A ) Gini(D,A) Gini(D,A)
  • 2)在所有的特征 A A A以及它们所有可能的切分点中,找出对应基尼指数最小 G i n i ( D , A ) Gini(D,A) Gini(D,A)的最优切分特征及取值,并依据最优特征和最优切分点,从现节点生成两个子节点,将训练数据集依特征分配到两个子节点中去,并判断是否满足停止条件
  • 3)递归调用1)2)
  • 4)生成CART决策树
    停止条件:节点中的样本个数小于预定阈值,或样本集的Gini系数小于预定阈值,或者没有更多的特征

5.2 CART回归树

假设 X X X Y Y Y分别为输入和输出,并且 Y Y Y是连续变量,给定训练集:
D = ( x 1 , y 1 ) , ( x 2 , y 2 ) , ( x 3 , y 3 ) , . . . , ( x n , y n ) D=(x_1, y_1), (x_2, y_2), (x_3, y_3), ..., (x_n, y_n) D=(x1,y1),(x2,y2),(x3,y3),...,(xn,yn)
一个回归树对应着输入空间(特征空间)的一个划分以及在划分的单元上的输出值。假设已将输入数据空间 X X X划分为 M M M个单元: R 1 , R 2 , R 3 , . . . , R M R_1, R_2, R_3, ..., R_M R1,R2,R3,...,RM,然后赋给每个输入空间的区域 R i R_i Ri有一个固定的代表输出值 c m c_m cm,则回归模型表示为:
f ( x ) = ∑ m = 1 M c m I ( x ∈ R m ) f(x)=\sum_{m=1}^{M}{c_mI}(x\in R_m) f(x)=m=1McmI(xRm)
当输入空间的划分确定后,可以使用平方误差 ∑ x i ∈ R m ( y i − f ( x i ) ) 2 \sum_{x_i\in R_m}(y_i-f(x_i))^2 xiRm(yif(xi))2来表示回归树对于训练数据的预测误差,用平方误差最小的准则求解每个单元上的最优输出值。易知单元 R m R_m Rm上的 c m c_m cm的最优值 c m ~ \tilde{c_m} cm~ R m R_m Rm上的所有输入实例 x i x_i xi对应的输出 y i y_i yi的均值,即:
c m ~ = a v e ( y i ∣ x i ∈ R m ) \tilde{c_m}=ave(y_i|x_i \in R_m) cm~=ave(yixiRm)
此处采用启发式方法对输入空间进行划分,选择第 j j j个变量 x ( j ) x^{(j)} x(j)和它的取值 s s s,作为切分变量和切分点,并定义两个区域:
R 1 ( j , s ) = { x ∣ x ( j ) ≤ s }    和    R 2 ( j , s ) = { x ∣ x ( j ) > s } R_1(j,s)=\{x|x^{(j)}\leq s\}\ \ 和\ \ R_2(j,s)=\{x|x^{(j)}\gt s\} R1(j,s)={xx(j)s}    R2(j,s)={xx(j)>s}
然后找到最优切分变量 j j j和最优切分点 s s s,具体来说就是求解:
m i n j , s [ m i n c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + m i n c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] min_{j,s}[min_{c_1}\sum_{x_i\in R_1(j,s)}(y_i-c_1)^2+min_{c_2}\sum_{x_i\in R_2(j,s)}(y_i-c_2)^2] minj,s[minc1xiR1(j,s)(yic1)2+minc2xiR2(j,s)(yic2)2]
对固定输入变量 j j j可以找到最优切分点 s s s
c 1 ~ = a v e ( y i ∣ x i ∈ R 1 ( j , s ) )    和    c 2 ~ = a v e ( y i ∣ x i ∈ R 2 ( j , s ) ) \tilde{c_1}=ave(y_i|x_i \in R_1(j,s))\ \ 和\ \ \tilde{c_2}=ave(y_i|x_i \in R_2(j,s)) c1~=ave(yixiR1(j,s))    c2~=ave(yixiR2(j,s))
遍历所有输入变量,找到最优的切分变量 j j j,构成一个对 ( j , s ) (j,s) (j,s),依次将输入空间划分为两个区域。接着对每个区域重复上述划分过程,直到满足停止条件为止。这样就生成了最小二乘回归树

5.3 剪枝

CART是基于代价复杂度的后剪枝策略,这种方法会生成一系列树,每个树都是通过将前面的树的某个或某些子树替换成一个叶节点而得到的,这一系列树中的最后一棵树仅含一个用来预测类别的叶节点。然后用一种成本复杂度的度量准则来判断哪棵子树应该被一个预测类别值的叶节点所代替。这种方法需要使用一个单独的测试数据集来评估所有的树,根据它们在测试数据集熵的分类性能选出最佳的树。

首先我们将最大树称为 T 0 T_0 T0,剪去一棵子树生成 T 1 T_1 T1 T 1 T_1 T1剪去一棵子树生成 T 2 T_2 T2,直到生成 T n T_n Tn。最后得到 T 0 T_0 T0 T n T_n Tn n + 1 n+1 n+1棵树,然后利用这n+1棵子树预测独立的验证数据集,谁的损失最小就选谁。

损失函数定义如下:
C a ( T ) = C ( T ) + α ∣ T ∣ C_a(T)=C(T)+\alpha|T| Ca(T)=C(T)+αT
其中 T T T为任意子树, C ( T ) C(T) C(T)为预测误差, ∣ T ∣ |T| T为子树 T T T的叶子节点的个数, α \alpha α为权重参数, C ( T ) C(T) C(T)衡量训练数据的拟合程度, ∣ T ∣ |T| T衡量树的复杂程度, α \alpha α权衡拟合程度与树的复杂度。

那么我们如何找到这个合适的 α \alpha α来使拟合程度与复杂度之间达到最好的平衡呢,最好的办法就是,我们将 α \alpha α 0 0 0取到正无穷,对于每一个固定的 α \alpha α,我们都可以找到使得 C α ( T ) C_\alpha(T) Cα(T)最小的最优子树 T ( α ) T(\alpha) T(α) 。当 α \alpha α很小的时候, T 0 T_0 T0是这样的最优子树,当 α \alpha α很大的时候,单独一个根节点是这样的最优的子树。

Breiman证明: α \alpha α从小增大, 0 = α 0 , α 1 , α 2 , . . . , α n < + ∞ 0=\alpha_0, \alpha_1, \alpha_2, ..., \alpha_n < +\infty 0=α0,α1,α2,...,αn<+, 在每一个区间 [ α i , α i + 1 ) [\alpha_i, \alpha_{i+1}) [αi,αi+1)中,子树 T i T_i Ti是这个区间中最优的。

每次剪枝剪的都是某个内部节点的子节点,也就是将某个内部节点的所有子节点回退到这个内部节点里,并将这个内部节点作为叶子节点。因此在计算整体的损失函数时,这个内部节点以外的值都没变,只有这个内部节点的局部损失函数改变了,因此原本需要计算全局的损失函数,但现在只需要计算内部节点剪枝前和剪枝后的损失函数。

对任意内部节点 t t t
剪枝前的状态:有 ∣ T t ∣ |T_t| Tt个叶子节点,预测误差是 C ( T t ) C(T_t) C(Tt)
剪枝后的状态:只有本身一个叶子节点,预测误差是 C ( t ) C(t) C(t)

因此剪枝前以 t t t节点为根节点的子树的损失为:
C a ( T t ) = C ( T t ) + α ∣ T t ∣ C_a(T_t)=C(T_t)+\alpha|T_t| Ca(Tt)=C(Tt)+αTt
剪枝后为:
C a ( T t ) = C ( T t ) + α C_a(T_t)=C(T_t)+\alpha Ca(Tt)=C(Tt)+α
容易得出,一定存在 α \alpha α,使得 C a ( T t ) = C a ( t ) C_a(T_t)=C_a(t) Ca(Tt)=Ca(t),这个值为:
α = C ( t ) − C ( T t ) ∣ T t ∣ − 1 \alpha=\frac{C(t)-C(T_t)}{|T_t| - 1} α=Tt1C(t)C(Tt)
对于当前节点 t t t,只要 α \alpha α大于这个值,一定存在 C a ( t ) < C a ( T t ) C_a(t)<C_a(T_t) Ca(t)<Ca(Tt),也就是剪掉这个节点比不剪掉更优,所有每一最优子树对应一个区间,在这个区间内部是最优的。

接着对 T i T_i Ti中的每一个节点都计算:
g ( t ) = C ( t ) − C ( T t ) ∣ T t ∣ − 1 g(t)=\frac{C(t)-C(T_t)}{|T_t|-1} g(t)=Tt1C(t)C(Tt)
g ( t ) g(t) g(t)表示整体损失函数减少的程度,最后剪掉 g ( t ) g(t) g(t)最小的 T t T_t Tt

算法步骤:

  • 输入:算法生成的决策树 T 0 T_0 T0
  • 输出:最优的决策树 T α T_\alpha Tα
  1. k = 0 k=0 k=0, T = T 0 T=T_0 T=T0
  2. α = ∞ \alpha=\infty α=
  3. 自上而下的对内部节点计算 C ( T t ) C(T_t) C(Tt), ∣ T t ∣ |T_t| Tt,以及:
    g ( t ) = C ( t ) − C ( T t ) ∣ T t ∣ − 1 g(t)=\frac{C(t)-C(T_t)}{|T_t|-1} g(t)=Tt1C(t)C(Tt)
    α = m i n ( α , g ( t ) ) \alpha=min(\alpha, g(t)) α=min(α,g(t))
    此处 C ( T t ) C(T_t) C(Tt)是对训练数据的预测误差
  4. 自上而下的访问内部节点 t t t,如果有 g ( t ) = α g(t)=\alpha g(t)=α,进行剪枝,并对叶子节点 t t t以多数表决发决定其类,得到树 T T T
  5. k = k + 1 k=k+1 k=k+1 α k = α \alpha_k=\alpha αk=α T k = T T_k=T Tk=T
  6. 如果T不是由根节点单独构成的树,则回到步骤4
  7. 采用交叉验证法在子树序列 T 0 , T 1 , T 2 , . . . , T n T_0, T_1, T_2, ..., T_n T0,T1,T2,...,Tn中选取最优子树 T α T_\alpha Tα

代码

import csv
from collections import defaultdict
import pydotplus
import numpy as np


# Important part
class Tree:
    def __init__(self, value=None, trueBranch=None, falseBranch=None, results=None, col=-1, summary=None, data=None):
        self.value = value
        self.trueBranch = trueBranch
        self.falseBranch = falseBranch
        self.results = results
        self.col = col
        self.summary = summary
        self.data = data


def calculateDiffCount(datas):
    # 将输入的数据汇总(input dataSet)
    # return results Set{type1:type1Count,type2:type2Count ... typeN:typeNCount}

    results = {}
    for data in datas:
        # data[-1] means dataType
        if data[-1] not in results:
            results[data[-1]] = 1
        else:
            results[data[-1]] += 1
    return results


def gini(rows):
    # 计算gini值(Calculate GINI)

    length = len(rows)
    results = calculateDiffCount(rows)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp


def splitDatas(rows, value, column):
    # 根据条件分离数据集(splitDatas by value,column)
    # return 2 part(list1,list2)

    list1 = []
    list2 = []
    if (isinstance(value, int) or isinstance(value, float)):  # for int and float type
        for row in rows:
            if (row[column] >= value):
                list1.append(row)
            else:
                list2.append(row)
    else:  # for String type
        for row in rows:
            if row[column] == value:
                list1.append(row)
            else:
                list2.append(row)

    return (list1, list2)


def buildDecisionTree(rows, evaluationFunction=gini):
    # 递归建立决策树,当gain = 0 时停止递归
    # bulid decision tree by recursive function
    # stop recursive function when gain = 0
    # return tree

    currentGain = evaluationFunction(rows)
    column_length = len(rows[0])
    rows_length = len(rows)
    best_gain = 0.0
    best_value = None
    best_set = None

    # choose the best gain
    for col in range(column_length - 1):
        col_value_set = set([x[col] for x in rows])
        for value in col_value_set:
            list1, list2 = splitDatas(rows, value, col)
            p = len(list1) / rows_length
            gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
            if gain > best_gain:
                best_gain = gain
                best_value = (col, value)
                best_set = (list1, list2)

    dcY = {'impurity': '%.3f' % currentGain, 'samples': '%d' % rows_length}

    # stop or not stop
    if best_gain > 0:
        trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
        falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
        return Tree(col=best_value[0], value=best_value[1], trueBranch=trueBranch, falseBranch=falseBranch, summary=dcY)
    else:
        return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)


def prune(tree, miniGain, evaluationFunction=gini):
    # 剪枝, when gain < mini Gain,合并(merge the trueBranch and the falseBranch)

    if tree.trueBranch.results == None: prune(tree.trueBranch, miniGain, evaluationFunction)
    if tree.falseBranch.results == None: prune(tree.falseBranch, miniGain, evaluationFunction)

    if tree.trueBranch.results != None and tree.falseBranch.results != None:
        len1 = len(tree.trueBranch.data)
        len2 = len(tree.falseBranch.data)
        len3 = len(tree.trueBranch.data + tree.falseBranch.data)
        p = float(len1) / (len1 + len2)
        gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(
            tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)
        if (gain < miniGain):
            tree.data = tree.trueBranch.data + tree.falseBranch.data
            tree.results = calculateDiffCount(tree.data)
            tree.trueBranch = None
            tree.falseBranch = None


def classify(data, tree):
    if tree.results != None:
        return tree.results
    else:
        branch = None
        v = data[tree.col]
        if isinstance(v, int) or isinstance(v, float):
            if v >= tree.value:
                branch = tree.trueBranch
            else:
                branch = tree.falseBranch
        else:
            if v == tree.value:
                branch = tree.trueBranch
            else:
                branch = tree.falseBranch
        return classify(data, branch)


#下面是辅助代码画出树
#Unimportant part
#plot tree and load data
def plot(decisionTree):
    """Plots the obtained decision tree. """

    def toString(decisionTree, indent=''):
        if decisionTree.results != None:  # leaf node
            return str(decisionTree.results)
        else:
            szCol = 'Column %s' % decisionTree.col
            if szCol in dcHeadings:
                szCol = dcHeadings[szCol]
            if isinstance(decisionTree.value, int) or isinstance(decisionTree.value, float):
                decision = '%s >= %s?' % (szCol, decisionTree.value)
            else:
                decision = '%s == %s?' % (szCol, decisionTree.value)
            trueBranch = indent + 'yes -> ' + toString(decisionTree.trueBranch, indent + '\t\t')
            falseBranch = indent + 'no  -> ' + toString(decisionTree.falseBranch, indent + '\t\t')
            return (decision + '\n' + trueBranch + '\n' + falseBranch)

    print(toString(decisionTree))


def dotgraph(decisionTree):
    global dcHeadings
    dcNodes = defaultdict(list)
    """Plots the obtained decision tree. """

    def toString(iSplit, decisionTree, bBranch, szParent="null", indent=''):
        if decisionTree.results != None:  # leaf node
            lsY = []
            for szX, n in decisionTree.results.items():
                lsY.append('%s:%d' % (szX, n))
            dcY = {"name": "%s" % ', '.join(lsY), "parent": szParent}
            dcSummary = decisionTree.summary
            dcNodes[iSplit].append(['leaf', dcY['name'], szParent, bBranch, dcSummary['impurity'],
                                    dcSummary['samples']])
            return dcY
        else:
            szCol = 'Column %s' % decisionTree.col
            if szCol in dcHeadings:
                szCol = dcHeadings[szCol]
            if isinstance(decisionTree.value, int) or isinstance(decisionTree.value, float):
                decision = '%s >= %s' % (szCol, decisionTree.value)
            else:
                decision = '%s == %s' % (szCol, decisionTree.value)
            trueBranch = toString(iSplit + 1, decisionTree.trueBranch, True, decision, indent + '\t\t')
            falseBranch = toString(iSplit + 1, decisionTree.falseBranch, False, decision, indent + '\t\t')
            dcSummary = decisionTree.summary
            dcNodes[iSplit].append([iSplit + 1, decision, szParent, bBranch, dcSummary['impurity'],
                                    dcSummary['samples']])
            return

    toString(0, decisionTree, None)
    lsDot = ['digraph Tree {',
             'node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;',
             'edge [fontname=helvetica] ;'
             ]
    i_node = 0
    dcParent = {}
    print(dcNodes)
    for nSplit, lsY in dcNodes.items():
        for lsX in lsY:
            iSplit, decision, szParent, bBranch, szImpurity, szSamples = lsX
            if type(iSplit) == int:
                szSplit = '%d-%s' % (iSplit, decision)
                dcParent[szSplit] = i_node
                lsDot.append('%d [label=<%s<br/>impurity %s<br/>samples %s>, fillcolor="#e5813900"] ;' % (i_node,
                                                                                                          decision.replace(
                                                                                                              '>=',
                                                                                                              '&ge;').replace(
                                                                                                              '?', ''),
                                                                                                          szImpurity,
                                                                                                          szSamples))
            else:
                lsDot.append('%d [label=<impurity %s<br/>samples %s<br/>class %s>, fillcolor="#e5813900"] ;' % (i_node,
                                                                                                                szImpurity,
                                                                                                                szSamples,
                                                                                                                decision))

            if szParent != 'null':
                if bBranch:
                    szAngle = '45'
                    szHeadLabel = 'True'
                else:
                    szAngle = '-45'
                    szHeadLabel = 'False'
                szSplit = '%d-%s' % (nSplit, szParent)
                print(dcParent)
                p_node = dcParent[szSplit]
                if nSplit == 1:
                    lsDot.append('%d -> %d [labeldistance=2.5, labelangle=%s, headlabel="%s"] ;' % (p_node,
                                                                                                    i_node, szAngle,
                                                                                                    szHeadLabel))
                else:
                    lsDot.append('%d -> %d ;' % (p_node, i_node))
            i_node += 1
    lsDot.append('}')
    dot_data = '\n'.join(lsDot)
    return dot_data

def loadCSV(file):
    """Loads a CSV file and converts all floats and ints into basic datatypes."""
    def convertTypes(s):
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s

    reader = csv.reader(open(file, 'rt'))
    dcHeader = {}
    if bHeader:
        lsHeader = next(reader)
        for i, szY in enumerate(lsHeader):
                szCol = 'Column %d' % i
                dcHeader[szCol] = str(szY)
    return dcHeader, [[convertTypes(item) for item in row] for row in reader]



bHeader = True
# the bigger example
# dcHeadings, trainingData = loadCSV('../data/DecisionTreeData.csv') # demo data from matlab
dcHeadings, trainingData = loadCSV('../data/fishiris.csv') # demo data from matlab
decisionTree = buildDecisionTree(trainingData, evaluationFunction=gini)

print('剪枝前:')
result = plot(decisionTree)
prune(decisionTree, 0.4) # notify, when a branch is pruned (one time in this example)
print('剪枝后:')
result = plot(decisionTree)
dot_data = dotgraph(decisionTree)
# graph = pydotplus.graph_from_dot_data(dot_data)
# graph.write_png("prune.png")
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yougwypf1991

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值