Decision Tree Induction Algorithms in Data Mining

Author's note: the decision tree lectures from our International Week instructor were seriously hardcore, so I'm writing up these notes right away to settle my nerves.

1.Motivation

  • Basic idea: recursively partition the input space during training, then traverse the tree with each test data point to predict its class
  • Classification problem setup:
    • training dataset
    • testing dataset
    • validation dataset
    • model
  • Transparent method: a tree-like structure that emulates a human's decision-making flow
    • Can be converted into decision rules
    • Similar to association rules


2.Decision Tree Structure


  • Node: attribute splitting
    • root, leaf, internal nodes
  • Branch: attribute condition testing
    • Binary or multiway splits

[Figure: example decision tree structure]

3.Framework of Supervised Learning

  • Induction: model building from training data
    • Specific -> General
  • Deduction: model prediction on testing data
    • General -> Specific
  • Eager vs. lazy learning: distinguished by whether an explicit induction step is present (decision trees are eager learners)


4.Application

Major Application of Decision Tree Induction Algorithm

  • Improve business decision making and decision support in many industries: finance, banking, insurance, healthcare, etc.
  • Enhance customer service levels
  • Knowledge management platform to facilitate easier knowledge findability

5.Algorithm Summary (Hunt’s Algorithm)

  • Goal: improve dataset purity by recursively splitting on attributes (a minimal code sketch follows this list)
  • Check whether the dataset D_t at the current node is pure: if yes, label it as a leaf node; if not, continue
  • Choose the attribute, and (for numerical attributes) the split point, that maximizes information gain, and split on it
  • Keep splitting until one of the stopping conditions is met:
    • all the data points belong to the same class
    • all the records have the same attribute values
    • early termination: triggered by model parameters (e.g. minsplit, minbucket, maxdepth) that pre-prune the tree
  • Other decision tree induction algorithms: ID3, C4.5, C5.0, CART

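As a minimal, illustrative sketch of this recursion (hypothetical helper names entropy and hunt_tree; categorical attributes only, entropy-based gain, and none of rpart's pruning machinery):

entropy <- function(y) {
  # class proportions; drop empty classes to avoid log2(0)
  p <- table(y) / length(y)
  p <- p[p > 0]
  -sum(p * log2(p))
}

hunt_tree <- function(data, target, attrs) {
  y <- data[[target]]
  if (nrow(data) == 0) return(NA)        # empty branch
  if (length(unique(y)) == 1 || length(attrs) == 0) {
    return(names(which.max(table(y))))   # pure node or no attributes left: majority-class leaf
  }
  # choose the attribute with the largest information gain
  gains <- sapply(attrs, function(a) {
    entropy(y) - sum(sapply(split(y, data[[a]]),
                            function(s) length(s) / length(y) * entropy(s)))
  })
  best <- attrs[which.max(gains)]
  # one child per value of the chosen attribute, recursing on the remaining attributes
  lapply(split(data, data[[best]]),
         hunt_tree, target = target, attrs = setdiff(attrs, best))
}

Calling hunt_tree(df, "churn", setdiff(names(df), "churn")) on a small categorical data frame returns a nested list whose leaves are majority-class labels; real implementations add numerical split search and the stopping parameters discussed below.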

6.Attributes for Decision Tree

  • Categorical attributes
    • Binary attributes: Classification And Regression Tree (CART) constructs binary trees
    • Multinomial attributes: values can be grouped to reduce the number of child nodes
  • Numerical attributes
    • Often discretized into binary attributes
    • Pick a splitting point (cutoff) on the attribute

中:

  • 分类属性
    • 二进制属性:分类和回归树(CART)构造二叉树
    • 多项属性:分组以减少子节点数
  • 数字属性
    • 经常被离散化为二进制属性
    • 在属性上选择一个分裂点(cut off)

7.Data Impurity Measure: Entropy

  • Entropy: a property of a dataset D and the classification C (the standard definition is reconstructed below)

  • Entropy curve for binary classification
    [Figure: entropy curve for binary classification]
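The entropy formula itself did not survive the original post; reconstructing it from the standard definition, which is presumably what was shown:

$$\mathrm{Entropy}(D) = -\sum_{i=1}^{k} p_i \log_2 p_i$$

where $p_i$ is the fraction of records in $D$ that belong to class $i$. For binary classification this reduces to $-p\log_2 p - (1-p)\log_2(1-p)$: it is 0 when the data is pure ($p = 0$ or $p = 1$) and peaks at 1 bit when $p = 0.5$, which is the shape of the curve in the figure above.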

8. Other Impurity Measure: Gini Index

8.1 Common characteristics of data impurity metrics

  • Correlates with data purity with respect to the target class label
  • If the data is more pure/homogeneous, the metric takes a lower value; if less pure/heterogeneous, a higher value

8.2 Gini index

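The Gini formula image is missing; the standard definition, reconstructed here, is:

$$\mathrm{Gini}(D) = 1 - \sum_{i=1}^{k} p_i^2$$

with $p_i$ again the fraction of records in class $i$.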

  • Special cases: a pure node has Gini index 0; for binary classification the maximum 0.5 occurs at a 50/50 class split
  • Used in CART (Classification And Regression Trees)

9.Information Gain

  • Information gain: a function of the entropy of (D, C) and an attribute A (the standard definition is reconstructed below)

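Reconstructing the missing formula from the standard definition used by ID3:

$$\mathrm{Gain}(D, A) = \mathrm{Entropy}(D) - \sum_{v \in \mathrm{values}(A)} \frac{|D_v|}{|D|}\,\mathrm{Entropy}(D_v)$$

where $D_v$ is the subset of $D$ whose value of attribute $A$ is $v$.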

Adopted in the ID3 algorithm

  • Gain ratio: information gain adjusted to penalize splits that produce many groups (reconstructed below)

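Again reconstructing the missing formula from the standard C4.5 definition:

$$\mathrm{GainRatio}(D, A) = \frac{\mathrm{Gain}(D, A)}{\mathrm{SplitInfo}(D, A)}, \qquad \mathrm{SplitInfo}(D, A) = -\sum_{v \in \mathrm{values}(A)} \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|}$$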

Adopted in the C4.5 algorithm

10.Occam’s Razor

  • Smaller models are preferred given similar training accuracy
  • The complexity of a decision tree is defined as the number of splits in the tree
  • Pruning: reduce the size of the decision tree
    • Prepruning: halt tree construction early; requires setting thresholds that stop attribute splitting
    • Postpruning: remove branches from a “fully grown” tree

11.Overfitting

  • Training accuracy vs. testing accuracy: as the tree grows more complex, training accuracy keeps rising while testing accuracy eventually falls

[Figure: training vs. testing accuracy as tree complexity grows]

12.Model Parameters

  • Set via the rpart.control() function in the rpart package:

    rpart.control(
        minsplit = 20,
        minbucket = round(minsplit/3),
        cp = 0.01,
        maxdepth = 30,
        ...
    )

  • minbucket: the minimum number of observations in any terminal node.

  • minsplit: the minimum number of observations that must exist in a node in order for a split to be attempted.

  • maxdepth: maximum depth of any node of the final tree, with the root node counted as depth 0.

  • Complexity parameter (cp): the minimum improvement in model fit required for a split to be kept

    • The lower cp is set, the more complex the model can grow; therefore, increase cp to prune
    • To grow a fully grown tree, set cp to a negative value so that no split is rejected for insufficient improvement (a sketch follows this list)
  • To avoid overfitting, increase minbucket, minsplit, or cp, or decrease maxdepth
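As a quick sketch of the fully-grown-tree trick (assuming the churnTrain data frame from the demo below is already loaded):

> library(rpart)
> # cp < 0: no split is ever rejected for lack of improvement, so only
> # minsplit/minbucket/maxdepth limit growth
> full_tree <- rpart(
      churn ~ .,
      data = churnTrain,
      control = rpart.control(cp = -1, minsplit = 2, minbucket = 1)
  )
> nrow(full_tree$frame)   # node count; far larger than with the default cp = 0.01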

13.Properties of the Algorithm

  • Greedy algorithm: top-down, recursive partitioning strategy
  • Rectilinear decision boundaries (rectangles or hyper-rectangles in the input space)
  • Data fragmentation: successive splits leave fewer and fewer records in the child nodes
  • Training (model building) is relatively slow; prediction is fast
  • Robust to outliers
  • Non-parametric model: no assumptions about the underlying data distribution
  • Outputs the model either as a tree or as a set of rules (similar to association rules)


14.Demo

14.1 The churn Dataset from the C50 Package

# install.packages("C50")
> library(C50)
> data(churn)
> churn <- rbind(churnTrain, churnTest)
> str(churnTrain)
'data.frame':	3333 obs. of  20 variables:
 $ state                        : Factor w/ 51 levels "AK","AL","AR",..: 17 36 32 36 37 2 20 25 19 50 ...
 $ account_length               : int  128 107 137 84 75 118 121 147 117 141 ...
 $ area_code                    : Factor w/ 3 levels "area_code_408",..: 2 2 2 1 2 3 3 2 1 2 ...
 $ international_plan           : Factor w/ 2 levels "no","yes": 1 1 1 2 2 2 1 2 1 2 ...
 $ voice_mail_plan              : Factor w/ 2 levels "no","yes": 2 2 1 1 1 1 2 1 1 2 ...
 $ number_vmail_messages        : int  25 26 0 0 0 0 24 0 0 37 ...
 $ total_day_minutes            : num  265 162 243 299 167 ...
 $ total_day_calls              : int  110 123 114 71 113 98 88 79 97 84 ...
 $ total_day_charge             : num  45.1 27.5 41.4 50.9 28.3 ...
 $ total_eve_minutes            : num  197.4 195.5 121.2 61.9 148.3 ...
 $ total_eve_calls              : int  99 103 110 88 122 101 108 94 80 111 ...
 $ total_eve_charge             : num  16.78 16.62 10.3 5.26 12.61 ...
 $ total_night_minutes          : num  245 254 163 197 187 ...
 $ total_night_calls            : int  91 103 104 89 121 118 118 96 90 97 ...
 $ total_night_charge           : num  11.01 11.45 7.32 8.86 8.41 ...
 $ total_intl_minutes           : num  10 13.7 12.2 6.6 10.1 6.3 7.5 7.1 8.7 11.2 ...
 $ total_intl_calls             : int  3 3 5 7 3 6 7 6 4 5 ...
 $ total_intl_charge            : num  2.7 3.7 3.29 1.78 2.73 1.7 2.03 1.92 2.35 3.02 ...
 $ number_customer_service_calls: int  1 1 0 2 3 0 3 0 1 0 ...
 $ churn                        : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...
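Before training, it is worth checking the class balance; the root-node class probabilities that appear later in the demo (roughly 14.5% "yes") show the data is skewed toward "no":

> table(churnTrain$churn) / nrow(churnTrain)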

14.2 Model Training

> library(caret)
> library(rpart)
> library(e1071)
> dt_model <- train(churn ~ ., data = churnTrain, metric = "Accuracy", method = "rpart")
> typeof(dt_model)
[1] "list"

> names(dt_model)
 [1] "method"       "modelInfo"    "modelType"    "results"      "pred"         "bestTune"    
 [7] "call"         "dots"         "metric"       "control"      "finalModel"  "preProcess"  
[13] "trainingData" "resample"     "resampledCM"  "perfNames"    "maximize"     "yLimits"     
[19] "times"        "levels"       "terms"        "coefnames"    "contrasts"    "xleves"     

14.3 Check Decision Tree Classifier

> print(dt_model)
CART 

3333 samples
  19 predictor
   2 classes: 'yes', 'no' 

No pre-processing
Resampling: Bootstrapped (25 reps) 
Summary of sample sizes: 3333, 3333, 3333, 3333, 3333, 3333, ... 
Resampling results across tuning parameters:

  cp          Accuracy   Kappa    
  0.07867495  0.8741209  0.3072049
  0.08488613  0.8683224  0.2475440
  0.08902692  0.8653671  0.2178997

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.07867495.

14.4 Check Decision Tree Classifier Details

> print(dt_model$finalModel)
n= 3333 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3333 483 no (0.1449145 0.8550855)  
  2) total_day_minutes>=264.45 211  84 yes (0.6018957 0.3981043)  
    4) voice_mail_planyes< 0.5 158  37 yes (0.7658228 0.2341772) *
    5) voice_mail_planyes>=0.5 53   6 no (0.1132075 0.8867925) *
  3) total_day_minutes< 264.45 3122 356 no (0.1140295 0.8859705) *

14.5 Model Prediction (1)

> dt_predict <- predict(
        dt_model, 
        newdata = churnTest, 
        na.action = na.omit, 
        type = "prob"
	)
> head(dt_predict, 5)
        yes        no
1 0.1140295 0.8859705
2 0.1140295 0.8859705
3 0.1132075 0.8867925
4 0.1140295 0.8859705
5 0.1140295 0.8859705

14.6 Model Prediction (2)

> dt_predict2 <- predict(
    dt_model, 
    newdata = churnTest, 
    type = "raw"
)
> head(dt_predict2)
[1] no no no no no no
Levels: yes no
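A natural follow-up, not part of the original transcript, is to score the raw predictions against the test labels with caret's confusionMatrix():

> confusionMatrix(dt_predict2, churnTest$churn)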

14.7 Model Tuning (1)

> dt_model_tune <- train(
    churn ~ ., 
    data = churnTrain, 
    method = "rpart",                       
    metric = "Accuracy",
    tuneLength = 8
)
> print(dt_model_tune$finalModel)
n= 3333 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 3333 483 no (0.14491449 0.85508551)  
    2) total_day_minutes>=264.45 211  84 yes (0.60189573 0.39810427)  
      4) voice_mail_planyes< 0.5 158  37 yes (0.76582278 0.23417722)  
        8) total_eve_minutes>=187.75 101   5 yes (0.95049505 0.04950495) *
        9) total_eve_minutes< 187.75 57  25 no (0.43859649 0.56140351)  
         18) total_day_minutes>=277.7 32  11 yes (0.65625000 0.34375000)  
           36) total_eve_minutes>=144.35 24   4 yes (0.83333333 0.16666667) *
           37) total_eve_minutes< 144.35 8   1 no (0.12500000 0.87500000) *
         19) total_day_minutes< 277.7 25   4 no (0.16000000 0.84000000) *
      5) voice_mail_planyes>=0.5 53   6 no (0.11320755 0.88679245) *
    3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)  
      6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)  
       12) total_day_minutes< 160.2 102  13 yes (0.87254902 0.12745098) *
       13) total_day_minutes>=160.2 149  38 no (0.25503356 0.74496644)  
         26) total_eve_minutes< 141.75 19   5 yes (0.73684211 0.26315789) *
         27) total_eve_minutes>=141.75 130  24 no (0.18461538 0.81538462)  
           54) total_day_minutes< 175.75 34  14 no (0.41176471 0.58823529)  
            108) total_eve_minutes< 212.15 16   2 yes (0.87500000 0.12500000) *
            109) total_eve_minutes>=212.15 18   0 no (0.00000000 1.00000000) *
           55) total_day_minutes>=175.75 96  10 no (0.10416667 0.89583333) *
      7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)  
       14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)  
         28) total_intl_calls< 2.5 51   0 yes (1.00000000 0.00000000) *
         29) total_intl_calls>=2.5 216  50 no (0.23148148 0.76851852)  
           58) total_intl_minutes>=13.1 43   0 yes (1.00000000 0.00000000) *
           59) total_intl_minutes< 13.1 173   7 no (0.04046243 0.95953757) *
       15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)  
         30) total_day_minutes>=223.25 383  68 no (0.17754569 0.82245431)  
           60) total_eve_minutes>=259.8 51  17 yes (0.66666667 0.33333333)  
            120) voice_mail_planyes< 0.5 40   6 yes (0.85000000 0.15000000) *
            121) voice_mail_planyes>=0.5 11   0 no (0.00000000 1.00000000) *
           61) total_eve_minutes< 259.8 332  34 no (0.10240964 0.89759036) *
         31) total_day_minutes< 223.25 2221  60 no (0.02701486 0.97298514) *

14.8 Model Tuning (2)

> dt_model_tune2 <- train(
        churn ~ ., 
        data = churnTrain, 
        method = "rpart",
        tuneGrid = expand.grid(cp = seq(0, 0.1, 0.01))
	)
> print(dt_model_tune2$finalModel)
n= 3333 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 3333 483 no (0.14491449 0.85508551)  
    2) total_day_minutes>=264.45 211  84 yes (0.60189573 0.39810427)  
      4) voice_mail_planyes< 0.5 158  37 yes (0.76582278 0.23417722)  
        8) total_eve_minutes>=187.75 101   5 yes (0.95049505 0.04950495) *
        9) total_eve_minutes< 187.75 57  25 no (0.43859649 0.56140351)  
         18) total_day_minutes>=277.7 32  11 yes (0.65625000 0.34375000)  
           36) total_eve_minutes>=144.35 24   4 yes (0.83333333 0.16666667) *
           37) total_eve_minutes< 144.35 8   1 no (0.12500000 0.87500000) *
         19) total_day_minutes< 277.7 25   4 no (0.16000000 0.84000000) *
      5) voice_mail_planyes>=0.5 53   6 no (0.11320755 0.88679245) *
    3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)  
      6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)  
       12) total_day_minutes< 160.2 102  13 yes (0.87254902 0.12745098) *
       13) total_day_minutes>=160.2 149  38 no (0.25503356 0.74496644)  
         26) total_eve_minutes< 141.75 19   5 yes (0.73684211 0.26315789) *
         27) total_eve_minutes>=141.75 130  24 no (0.18461538 0.81538462)  
           54) total_day_minutes< 175.75 34  14 no (0.41176471 0.58823529)  
            108) total_eve_minutes< 212.15 16   2 yes (0.87500000 0.12500000) *
            109) total_eve_minutes>=212.15 18   0 no (0.00000000 1.00000000) *
           55) total_day_minutes>=175.75 96  10 no (0.10416667 0.89583333) *
      7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)  
       14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)  
         28) total_intl_calls< 2.5 51   0 yes (1.00000000 0.00000000) *
         29) total_intl_calls>=2.5 216  50 no (0.23148148 0.76851852)  
           58) total_intl_minutes>=13.1 43   0 yes (1.00000000 0.00000000) *
           59) total_intl_minutes< 13.1 173   7 no (0.04046243 0.95953757) *
       15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)  
         30) total_day_minutes>=223.25 383  68 no (0.17754569 0.82245431)  
           60) total_eve_minutes>=259.8 51  17 yes (0.66666667 0.33333333)  
            120) voice_mail_planyes< 0.5 40   6 yes (0.85000000 0.15000000) *
            121) voice_mail_planyes>=0.5 11   0 no (0.00000000 1.00000000) *
           61) total_eve_minutes< 259.8 332  34 no (0.10240964 0.89759036) *
         31) total_day_minutes< 223.25 2221  60 no (0.02701486 0.97298514) *

14.9 Model Pre-Pruning

> dt_model_preprune <- train(
        churn ~ ., 
        data = churnTrain, 
        method = "rpart",
        metric = "Accuracy",
        tuneLength = 8,
        control = rpart.control(
            minsplit = 50, 
            minbucket = 20, 
            maxdepth = 5
        )
	)
> print(dt_model_preprune$finalModel)
n= 3333 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 3333 483 no (0.14491449 0.85508551)  
   2) total_day_minutes>=264.45 211  84 yes (0.60189573 0.39810427)  
     4) voice_mail_planyes< 0.5 158  37 yes (0.76582278 0.23417722)  
       8) total_eve_minutes>=187.75 101   5 yes (0.95049505 0.04950495) *
       9) total_eve_minutes< 187.75 57  25 no (0.43859649 0.56140351)  
        18) total_day_minutes>=277.7 32  11 yes (0.65625000 0.34375000) *
        19) total_day_minutes< 277.7 25   4 no (0.16000000 0.84000000) *
     5) voice_mail_planyes>=0.5 53   6 no (0.11320755 0.88679245) *
   3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)  
     6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)  
      12) total_day_minutes< 160.2 102  13 yes (0.87254902 0.12745098) *
      13) total_day_minutes>=160.2 149  38 no (0.25503356 0.74496644)  
        26) total_eve_minutes< 155.5 29  11 yes (0.62068966 0.37931034) *
        27) total_eve_minutes>=155.5 120  20 no (0.16666667 0.83333333) *
     7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)  
      14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)  
        28) total_intl_calls< 2.5 51   0 yes (1.00000000 0.00000000) *
        29) total_intl_calls>=2.5 216  50 no (0.23148148 0.76851852)  
          58) total_intl_minutes>=13.1 43   0 yes (1.00000000 0.00000000) *
          59) total_intl_minutes< 13.1 173   7 no (0.04046243 0.95953757) *
      15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)  
        30) total_day_minutes>=223.25 383  68 no (0.17754569 0.82245431)  
          60) total_eve_minutes>=259.8 51  17 yes (0.66666667 0.33333333) *
          61) total_eve_minutes< 259.8 332  34 no (0.10240964 0.89759036) *
        31) total_day_minutes< 223.25 2221  60 no (0.02701486 0.97298514) *

14.10 Model Post-Pruning

> dt_model_postprune <- prune(dt_model$finalModel, cp = 0.2)
> print(dt_model_postprune)
n= 3333 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3333 483 no (0.1449145 0.8550855) *
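A cp of 0.2 exceeds the improvement contributed by any split in this tree, so pruning collapses it to the root alone. A sketch of a more typical choice, assuming the fitted rpart object carries a cross-validated cptable (rpart's default is xval = 10; caret may configure this differently):

> fit <- dt_model_tune$finalModel
> # pick the cp minimizing cross-validated error, then prune to it
> best_cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
> dt_model_postprune2 <- prune(fit, cp = best_cp)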

14.11 Check Decision Tree Classifier (1)

> plot(dt_model$finalModel)
> text(dt_model$finalModel)

[Figure: base-graphics plot of the final rpart tree]

14.12 Check Decision Tree Classifier (2)

> library(rattle)
> fancyRpartPlot(dt_model$finalModel)

[Figure: fancyRpartPlot rendering of the final tree]

