决策树的概念及代码实现

定义:

树的一种,从根节点一一步步到叶子节点,所有的数据都在叶子节点上,决策树既可以用来分类也可以用来做回归。

在这里插入图片描述
决策树中的节点:根节点是决策的开始选择项,中间节点是选择的过程,叶子节点是选择的结果,每次增加一个节点,就相当于对数据切了一刀。
在这里插入图片描述
如果 对数曲线0-1,区间内横轴表示事情的概率,纵轴的绝对值表示熵的大小,这表明,概率越小的事件y值越大,就是熵越大,因为他的 不确定性越大,混乱度越高,概率越大的事件越趋近于1,熵接近于0,表示没有混乱度。表示这个事件是确定的。
我们当然选择熵大的作为根节点的分类,因为分类后更加明确,不会出现各种数据。
在这里插入图片描述
信息增益的概念:表示特征X(例如姓名,年龄等)使得类Y的不确定性减少的程度,这是什么意思呢,特征1作为节点使得我的熵值由10降低到 8,那么10-8的差值2就是熵值减少的成都 ,也就是不确定性减少的程度。如果某个特征的信息增益率最大,则表示经过这个特征后样本的不确定性更少,他就可以作为根节点,那第二个节点就是信息增益率第二个大的那个。

决策树构造实例:
首先需要明确决策树可以用来分类,也可以用来解决线性回归问题,甚至可以将其运用到逻辑回归,他的数据和逻辑回归是非常类似的。
其次需要知道熵的计算方程:H(x) = -∑ pi * logpi,i=1,2,…,n
在这里插入图片描述
对于这个数据。先计算最终的熵值。
在历史数据中(14天)有9天打球,5天不打球,所以熵应为:
-(9/14)log2(9/14)-(5/14)log2(5/14) = 0.94
四个特征逐一分析:
outlook = sunny时,熵值为0.971
outlook = overcast时,熵值为 0
outlook = rainy时,熵值 为 0.971

根据数据统计,outlook取值分别为sunny,overcast,rainy的概率分别为5/14,4/14,5/14最终熵值计算,5/14 0.971 + 4/140 +5/14*0.971 = 0.693
信息增益从0.940降低到了0.693,增益为0.247
从此可以算其他的。

这个方法叫做ID3算法,是存在一定问题的,上面算的都是信息增益,如果除以自身的熵值,才叫做信息增益率
对于连续数据,用二分法找到他的中间值,进行离散和分类

决策树剪枝策略:
为什么要剪枝:决策树拟合风险过大,理论上完全分得开数据
剪枝策略:预减枝,后剪枝。
预剪枝:边建立决策树边进行剪枝的操作,限制深度,叶子节点个数,叶子节点样本数和信息增益量。
后剪枝:当建立完角儿书后来进行剪枝操作。通过一定的衡量标准
(叶子节点越多,损失越大)

代码实现:

%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets.california_housing import fetch_california_housing
housing = fetch_california_housing()
print(housing.data)
#开启内部数据

数据结果

[[   8.3252       41.            6.98412698 ...,    2.55555556   37.88
  -122.23      ]
 [   8.3014       21.            6.23813708 ...,    2.10984183   37.86
  -122.22      ]
 [   7.2574       52.            8.28813559 ...,    2.80225989   37.85
  -122.24      ]
 ..., 
 [   1.7          17.            5.20554273 ...,    2.3256351    39.43
  -121.22      ]
 [   1.8672       18.            5.32951289 ...,    2.12320917   39.43
  -121.32      ]
 [   2.3886       16.            5.25471698 ...,    2.61698113   39.37
  -121.24      ]]

我们只用纬度和经度去看看对房价的影响
决策树参数有许多,具体看API,列出常用的 max_depth预剪枝,数据量少的时候可以不管这个值,如果模型样本量多,特征也多的情况下,可以尝试限制 min_samples_split,如果某节点的样本数少于这个值,则不会继续再尝试选择最优特征来进行划分,如果样本量不大,不需要管这个值,如果样本量非常大,则需要增大这个值。如果某个叶子节点的数小于这个设置值,就不会继续分裂了。

#导入树模块
from sklearn import tree
#构建决策树,预处理最大深度是2
dtr = tree.DecisionTreeRegressor(max_depth = 2)
#建立模型
dtr.fit(housing.data[:,[6,7]],housing.target)
#建立模型会对模型进行计算,按照信息增益或者信息增益率下降最大的值进行选择根节点的特征,下面各个节点的特征,我们就能找到切分的各个点。
#这样新的数据来了后可以根据自己的特征来预测Y值属于哪一类。

输出结果:

DecisionTreeRegressor(criterion='mse', max_depth=2, max_features=None,
           max_leaf_nodes=None, min_impurity_decrease=0.0,
           min_impurity_split=None, min_samples_leaf=1,
           min_samples_split=2, min_weight_fraction_leaf=0.0,
           presort=False, random_state=None, splitter='best')
#要可视化显示树的结构,需要安装graphviz     http://www.graphviz.org/Download..php
#安装完毕后配置环境变量
dot_data = \
    tree.export_graphviz(
        dtr,#树的名字
        out_file=None,
        feature_names = housing.feature_names[6:8],#树的名字,其他不用改
        filled = True,
        impurity = False,
        rounded = True
    )
    #最后会生成一个dot文件
#pip install pydotplus
#path = r"C:\Program Files (x86)\Graphviz2.38\bin"修改路径,给他个新路径,不然会报错
#修改Python2.7\Lib\site-packages\pydot.py
import pydotplus
graph = pydotplus.graph_from_dot_data(dot_data)
graph.get_nodes()[7].set_fillcolor("#FF2DD") #把文件加载进去,只需要改颜色,其他不用改。
from IPython.display import Image
Image(graph.create_png())

在这里插入图片描述

保存文件到本地

graph.write_png("dtr_white_background.png")

以上的过程都是建立决策树,并且可视化,真正的拟合如下

#建立训练数据
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(housing.data[:,[6,7]],housing.target,test_size = 0.33,random_state=42)
dtr = tree.DecisionTreeRegressor(random_state = 42)
#这个树有两种选择,一个是分类,一个是回归,我们选择回归
dtr.fit(x_train,y_train)
#这个里面预测可以选择不同的方法,有score,有predict等,算法不一样
dtr.score(x_test,y_test)

输出结果为:0.71043942928033132
决策树的所有方法
apply(X[, check_input]) 返回每个样本的叶节点的预测序号
decision_path(X[, check_input]) 返回决策树的决策路径 [n_samples, n_nodes]
fit(X, y[, sample_weight, check_input, …]) 从训练数据建立决策树,返回一个对象
fit_transform(X[, y]) 将数据X转换[n_samples, n_features_new]
get_params([deep]) 得到估计量的参数,返回一个映射
predict(X[, check_input]) 预测X的分类或者回归,返回[n_samples]
predict_log_proba(X) 预测输入样本的对数概率,返回[n_samples, n_classes]
predict_proba(X[, check_input]) 预测输入样本的属于各个类的概率[n_samples, n_classes]
score(X, y[, sample_weight]) 返回对于测试数据的平均准确率
set_params(**params) 设置估计量的参数
transform(*args, **kwargs) 将输入参数X减少的最重要的特征,返回[n_samples, n_selected_features]

作者:hainingwyx
链接:https://www.jianshu.com/p/59b510bafb4d
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

#有一个库非常重要,让机器自动帮我们选择最佳参数

from sklearn.grid_search import GridSearchCV #最好的方法就是让我们自己选择一堆参数进行遍历,看看哪个参数最好,这个库可以帮我们这一点
tree_param_grid = {'min_samples_split':list((3,6,9,12,14,15,17,18,20,22)),'max_depth':list((3,5,7,9,10,11,12,14,16,17))}#选取要循环的参数,就是要测试的参数
grid = GridSearchCV(tree.DecisionTreeRegressor(),param_grid=tree_param_grid,cv=5)#选择模型,选择CV,就是交叉验证,如果不进行
grid.fit(x_train,y_train)
grid.grid_scores_, grid.best_params_, grid.best_score_
#交叉验证,为了确定选择的参数是否准确,交差验证的原理是,把本身拿到的训练数据分成三份,分别叫做1,2,3,先把1和2建立模型,3作为测试数据,然后验证参数到底怎么样,然后2,3作为建立模型,看1作为测试数据,然后以此类推训练三次,这样做可以排除误差。不然如果验证集全是简单的数据会导致模型效果偏高或者偏低。

输出结果给出了最优的参数

([mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 3},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 6},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 9},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 12},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 14},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 15},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 17},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 18},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 20},
  mean: 0.23863, std: 0.00609, params: {'max_depth': 3, 'min_samples_split': 22},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 3},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 6},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 9},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 12},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 14},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 15},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 17},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 18},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 20},
  mean: 0.45076, std: 0.01250, params: {'max_depth': 5, 'min_samples_split': 22},
  mean: 0.57266, std: 0.01323, params: {'max_depth': 7, 'min_samples_split': 3},
  mean: 0.57275, std: 0.01310, params: {'max_depth': 7, 'min_samples_split': 6},
  mean: 0.57201, std: 0.01238, params: {'max_depth': 7, 'min_samples_split': 9},
  mean: 0.57174, std: 0.01258, params: {'max_depth': 7, 'min_samples_split': 12},
  mean: 0.57168, std: 0.01251, params: {'max_depth': 7, 'min_samples_split': 14},
  mean: 0.57181, std: 0.01266, params: {'max_depth': 7, 'min_samples_split': 15},
  mean: 0.57183, std: 0.01264, params: {'max_depth': 7, 'min_samples_split': 17},
  mean: 0.57184, std: 0.01252, params: {'max_depth': 7, 'min_samples_split': 18},
  mean: 0.57169, std: 0.01266, params: {'max_depth': 7, 'min_samples_split': 20},
  mean: 0.57152, std: 0.01256, params: {'max_depth': 7, 'min_samples_split': 22},
  mean: 0.64256, std: 0.00583, params: {'max_depth': 9, 'min_samples_split': 3},
  mean: 0.64359, std: 0.00586, params: {'max_depth': 9, 'min_samples_split': 6},
  mean: 0.64353, std: 0.00532, params: {'max_depth': 9, 'min_samples_split': 9},
  mean: 0.64495, std: 0.00601, params: {'max_depth': 9, 'min_samples_split': 12},
  mean: 0.64546, std: 0.00613, params: {'max_depth': 9, 'min_samples_split': 14},
  mean: 0.64554, std: 0.00615, params: {'max_depth': 9, 'min_samples_split': 15},
  mean: 0.64540, std: 0.00560, params: {'max_depth': 9, 'min_samples_split': 17},
  mean: 0.64532, std: 0.00551, params: {'max_depth': 9, 'min_samples_split': 18},
  mean: 0.64483, std: 0.00568, params: {'max_depth': 9, 'min_samples_split': 20},
  mean: 0.64405, std: 0.00586, params: {'max_depth': 9, 'min_samples_split': 22},
  mean: 0.66494, std: 0.00609, params: {'max_depth': 10, 'min_samples_split': 3},
  mean: 0.66662, std: 0.00515, params: {'max_depth': 10, 'min_samples_split': 6},
  mean: 0.66730, std: 0.00596, params: {'max_depth': 10, 'min_samples_split': 9},
  mean: 0.66883, std: 0.00677, params: {'max_depth': 10, 'min_samples_split': 12},
  mean: 0.66889, std: 0.00648, params: {'max_depth': 10, 'min_samples_split': 14},
  mean: 0.66872, std: 0.00634, params: {'max_depth': 10, 'min_samples_split': 15},
  mean: 0.66856, std: 0.00557, params: {'max_depth': 10, 'min_samples_split': 17},
  mean: 0.66897, std: 0.00566, params: {'max_depth': 10, 'min_samples_split': 18},
  mean: 0.66829, std: 0.00585, params: {'max_depth': 10, 'min_samples_split': 20},
  mean: 0.66721, std: 0.00585, params: {'max_depth': 10, 'min_samples_split': 22},
  mean: 0.67943, std: 0.00430, params: {'max_depth': 11, 'min_samples_split': 3},
  mean: 0.68174, std: 0.00301, params: {'max_depth': 11, 'min_samples_split': 6},
  mean: 0.68388, std: 0.00254, params: {'max_depth': 11, 'min_samples_split': 9},
  mean: 0.68593, std: 0.00492, params: {'max_depth': 11, 'min_samples_split': 12},
  mean: 0.68591, std: 0.00430, params: {'max_depth': 11, 'min_samples_split': 14},
  mean: 0.68578, std: 0.00377, params: {'max_depth': 11, 'min_samples_split': 15},
  mean: 0.68552, std: 0.00357, params: {'max_depth': 11, 'min_samples_split': 17},
  mean: 0.68582, std: 0.00468, params: {'max_depth': 11, 'min_samples_split': 18},
  mean: 0.68554, std: 0.00453, params: {'max_depth': 11, 'min_samples_split': 20},
  mean: 0.68418, std: 0.00459, params: {'max_depth': 11, 'min_samples_split': 22},
  mean: 0.69061, std: 0.00740, params: {'max_depth': 12, 'min_samples_split': 3},
  mean: 0.69432, std: 0.00610, params: {'max_depth': 12, 'min_samples_split': 6},
  mean: 0.69615, std: 0.00637, params: {'max_depth': 12, 'min_samples_split': 9},
  mean: 0.69918, std: 0.00848, params: {'max_depth': 12, 'min_samples_split': 12},
  mean: 0.69952, std: 0.00798, params: {'max_depth': 12, 'min_samples_split': 14},
  mean: 0.69913, std: 0.00762, params: {'max_depth': 12, 'min_samples_split': 15},
  mean: 0.69854, std: 0.00743, params: {'max_depth': 12, 'min_samples_split': 17},
  mean: 0.69874, std: 0.00838, params: {'max_depth': 12, 'min_samples_split': 18},
  mean: 0.69838, std: 0.00742, params: {'max_depth': 12, 'min_samples_split': 20},
  mean: 0.69660, std: 0.00760, params: {'max_depth': 12, 'min_samples_split': 22},
  mean: 0.69220, std: 0.00973, params: {'max_depth': 14, 'min_samples_split': 3},
  mean: 0.70085, std: 0.00812, params: {'max_depth': 14, 'min_samples_split': 6},
  mean: 0.70571, std: 0.00718, params: {'max_depth': 14, 'min_samples_split': 9},
  mean: 0.70998, std: 0.00867, params: {'max_depth': 14, 'min_samples_split': 12},
  mean: 0.70973, std: 0.00891, params: {'max_depth': 14, 'min_samples_split': 14},
  mean: 0.70881, std: 0.00882, params: {'max_depth': 14, 'min_samples_split': 15},
  mean: 0.70772, std: 0.00738, params: {'max_depth': 14, 'min_samples_split': 17},
  mean: 0.70845, std: 0.00761, params: {'max_depth': 14, 'min_samples_split': 18},
  mean: 0.70807, std: 0.00577, params: {'max_depth': 14, 'min_samples_split': 20},
  mean: 0.70759, std: 0.00584, params: {'max_depth': 14, 'min_samples_split': 22},
  mean: 0.69730, std: 0.00830, params: {'max_depth': 16, 'min_samples_split': 3},
  mean: 0.70805, std: 0.00610, params: {'max_depth': 16, 'min_samples_split': 6},
  mean: 0.71735, std: 0.00577, params: {'max_depth': 16, 'min_samples_split': 9},
  mean: 0.71861, std: 0.00896, params: {'max_depth': 16, 'min_samples_split': 12},
  mean: 0.71796, std: 0.00855, params: {'max_depth': 16, 'min_samples_split': 14},
  mean: 0.71609, std: 0.00789, params: {'max_depth': 16, 'min_samples_split': 15},
  mean: 0.71457, std: 0.00562, params: {'max_depth': 16, 'min_samples_split': 17},
  mean: 0.71526, std: 0.00673, params: {'max_depth': 16, 'min_samples_split': 18},
  mean: 0.71610, std: 0.00699, params: {'max_depth': 16, 'min_samples_split': 20},
  mean: 0.71580, std: 0.00624, params: {'max_depth': 16, 'min_samples_split': 22},
  mean: 0.69742, std: 0.00819, params: {'max_depth': 17, 'min_samples_split': 3},
  mean: 0.70963, std: 0.00510, params: {'max_depth': 17, 'min_samples_split': 6},
  mean: 0.71986, std: 0.00579, params: {'max_depth': 17, 'min_samples_split': 9},
  mean: 0.72076, std: 0.00850, params: {'max_depth': 17, 'min_samples_split': 12},
  mean: 0.72019, std: 0.00806, params: {'max_depth': 17, 'min_samples_split': 14},
  mean: 0.71901, std: 0.00848, params: {'max_depth': 17, 'min_samples_split': 15},
  mean: 0.71713, std: 0.00600, params: {'max_depth': 17, 'min_samples_split': 17},
  mean: 0.71788, std: 0.00730, params: {'max_depth': 17, 'min_samples_split': 18},
  mean: 0.71894, std: 0.00798, params: {'max_depth': 17, 'min_samples_split': 20},
  mean: 0.71845, std: 0.00712, params: {'max_depth': 17, 'min_samples_split': 22}],
 {'max_depth': 17, 'min_samples_split': 12},
 0.7207555796872562)
#把上面的结果代入
dtr = tree.DecisionTreeRegressor(random_state = 42,max_depth=17,min_samples_split=12)
#这个树有两种选择,一个是分类,一个是回归,我们选择回归
dtr.fit(x_train,y_train)
#这个里面预测可以选择不同的方法,有score,有predict等,算法不一样
dtr.score(x_test,y_test)

输出结果为0.73

决策树的含义:
这个决策树进行回归是什么意思呢,就是把一堆X和Y建立训练数据,然后最终经过决策树决策,分类,最终的叶子节点有样本的个数和value值,也就是一堆样本按照决策树分了好多分,这一堆X,属于这个Y值,这样的话效率是很低的,需要森林。如果是决策树分类的话,就是这一堆x属于这个类

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值