1.什么是决策树?
决策树是一种对实例进行分类的树形结构,决策树由节点和有向边组成,节点分为内部节点和叶节点。其中每个内部节点表示一个属性或者特征,每个叶节点代表一种分类结果,每个分支代表对某一个属性或者特征的输出。
下图是一个简单的决策树实例:
2.决策树模型的两种解释
- 决策树与if-then规则:决策树的根节点到叶节点到每一条路径构建一条规则,路径上内部节点到特征对应着规则的条件,而叶节点到类对应着规则的结论。
- 决策树与条件概率分布:决策树还可以表示给定特征条件下类的条件概率分布,其定义在特征空间的一个划分。假设X为表示特征的随机变量,Y为表示类的随机变量,那么这个条件概率分布可以表示为P(X|Y),各叶节点上的条件概率往往偏向于某一个类,即属于某一个类的概率越大。
3.特征选择
3.1信息增益
3.1.1熵
在信息论与概率统计中,熵表示对随机变量不确定性的度量,假设X是一个去有限个值的离散随机变量,其概率分布为 :
则随机变量X的熵定义为:
3.1.2条件熵
H(Y|X)表示在已知随机变量X的条件下随机变量的不确定性,定义为X给定条件下Y的条件概率分布的熵对X的数学期望:
需要注意的是:当熵和条件熵中的概率由极大似然估计或者其他数据估计方法所得到时,所对应的熵和条件熵称为经验熵和条件经验熵。
3.1.3信息增益
增益表示在已知特征A的信息下,使得类Y的信息的不确定性减少的程度。特征A对训练数据集D的信息增益表示为g(D,A),定义为训练数据集D的经验熵H(D)与给定特征A的条件下D的条件经验熵H(D|A)之差,即:
一般来说,将熵与条件熵之差称为互信息,因此决策树中的信息增益等价于训练数据集中类与特征的互信息。
3.1.4信息增益算法
信息增益准则:对训练数据集D,计算其每个特征的信息增益,并且比较它们的大小,选择信息增益最大的特征。
首先简单介绍一下各个符号的含义:D表示训练的数据集,|D|表示训练数据集的样本个数,表示共有K个类,表示属于类的样本个数。特征A有n个不同的取值,根据特征A的取值划分n个子集,为的样本个数,记子集中属于类的样本的集合为,为的样本个数。
具体算法流程:
- 1)计算数据集D的经验熵H(D):
- 2)计算特征A对数据集D的经验条件熵H(D|A):
- 3)计算信息增益
3.2 信息增益比
使用信息增益作为划分训练数据集的特征,存在偏向于选择取值较多的特征的问题,在此基础上,我们选择信息增益比可以解决这一问题。
信息增益比表示为,定义为其信息增益与训练数据集D关于特征A的值的熵之比,即:
3.3 基尼指数
在分类问题中,假设有K个类,样本点属于第k类的概率为,则概率分布的基尼指数定义为:
若样本集合D根据特征A是否取某一可能值被分为和两部分,则在特征A的条件下,集合D的基尼指数定义为:
4.决策树的生成
4.1 ID3算法
4.2 C4.5
4.3 CART
分类树与回归树CART模型同样由特征选择、决策树生成和决策树剪枝组成,不过不同的一点是CART模型既可以用于分类,也可以用于回归。
4.3.1 回归树生成
4.3.2 分类树生成
5.决策树的剪枝
5.1 树的剪枝算法
5.2 CART剪枝
6.代码
import matplotlib
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz, plot_tree
from sklearn.model_selection import train_test_split
import pydotplus
from IPython.display import display, Image
import matplotlib.pyplot as plt
adult_data = pd.read_csv('DecisionTree.csv')
feature_columns = [u'workclass', u'education', u'marital-status', u'occupation', u'relationship', u'race', u'gender', u'native-country']
label_column = ['income']
features = adult_data[feature_columns]
label = adult_data[label_column]
features = pd.get_dummies(features)
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
clf = clf.fit(features.values, label.values)
#采用自带的plot_tree,也可以用graphviz
plt.figure(figsize=(15,9))
plot_tree(clf, filled=True, feature_names=features.columns, class_names=['<=50k', '>50k'])
plt.show()