决策树(Decision Tree,DT),是树模型系列的根基模型。后续的随机森林(RF)、提升树(Boosting Tree)、梯度提升树(GBDT)、XGBoost都是在其基础上演化而来。
决策树及其演化模型(RF、GBDT、XGBoost)在数据挖掘、推荐系统、金融风控、计算广告、智能营销等领域得到广泛应用,是机器学习最基础模型之一,应该必须掌握。
本文以PPT的形式,首先回顾最优码、信息熵等决策树背后的数学逻辑(概率论、信息论);接着,介绍ID3树、C4.5树,CART树的原理;然后,介绍使用Sklearn实现ID3树、C4.5树,CART树的可视化;最后,对决策树优缺点、对比、演化进行总结。本文主要目录如下:
一、决策树基础
1、最优码
2、信息熵
3、分类、回归 二、决策树原理
1、决策树原理
2、CART决策树
3、ID3决策树
4、C4.5决策树 三、决策树可视化
1、ID3分类树可视化
2、CART分类树可视化
3、CART回归树可视化 四、决策树总结
1、决策树的优缺点
2、决策树 VS 线性回归
3、决策树的演化
直接上PPT。
一、决策树基础
1、最优码
信息熵的公式离不开Kraft不等式定理、最优码。
所以首先要搞清楚:什么是Kraft不等式定理?什么是最优码?
2、信息熵
什么是熵?什么是信息熵?
3、分类、回归
机器学习世界两个基本任务:分类任务、回归任务。
二、决策树原理
1、决策树原理
什么是决策树?
用一系列分支语句表示的模型。
所以决策树的本质是:分支语句。
决策树有一个很强的假设:信息是可分的,否则无法进行特征分支。
2、CART决策树
什么是分类与回归树(Classification and regression tree,CART)?
CART回归树算法如下:
举个CART回归树的例子如下:
例1的CART回归树的第1次切分、第2次切分:
例1的CART回归树的第3次切分、第4次切分:
例1的第5次切分,得到最终CART回归树:
例1的CART回归树的可视化如下:
例1的CART回归树的Python实现如下:
例1的CART回归树的Python实现代码如下:
"""
@author: 刘启林
@des:基于 sklearn 实现 CART 回归树的可视化
"""
import numpy as np
from sklearn.tree import DecisionTreeRegressor
X = np.arange(1, 11).reshape(-1, 1)
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
tree = DecisionTreeRegressor(max_depth=4).fit(X, y)
# import os
# os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
from sklearn.externals.six import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
dot_data = export_graphviz(tree, out_file=None,
filled=True, rounded=True,
special_characters=True,
precision=2)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
3、ID3决策树
ID3的特征选择标准:信息增益。
4、C4.5决策树
三、决策树可视化
1、ID3分类树可视化
ID3分类树可视化代码如下:
"""
@author: 刘启林
@des:基于Sklearn的ID3分类树可视化
"""
import pandas as pd
import sklearn.datasets as datasets
from sklearn.tree import DecisionTreeClassifier
iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns = iris.feature_names)
y = iris.target
# ID3分类树,信息增益特征选择
dtree = DecisionTreeClassifier(criterion='entropy', max_depth=3).fit(df, y)
from sklearn.externals.six import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
dot_data = StringIO()
export_graphviz(dtree, out_file = dot_data, filled = True, rounded = True,
special_characters = True, precision=2)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
2、CART分类树可视化
CART分类树可视化的代码如下:
"""
@author: 刘启林
@des:基于Sklearn的CART分类树可视化
"""
import pandas as pd
import sklearn.datasets as datasets
from sklearn.tree import DecisionTreeClassifier
iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns = iris.feature_names)
y = iris.target
# CART分类树,基尼系数特征选择
dtree = DecisionTreeClassifier(criterion='gini').fit(df, y)
from sklearn.externals.six import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
dot_data = StringIO()
export_graphviz(dtree, out_file = dot_data, filled = True, rounded = True,
special_characters = True, precision=2)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
3、CART回归树可视化
CART回归树可视化的代码如下:
"""
@author: 刘启林
@des:基于 sklearn 实现 CART 回归树的可视化
"""
import numpy as np
from sklearn.tree import DecisionTreeRegressor
X = np.arange(1, 11).reshape(-1, 1)
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
tree = DecisionTreeRegressor(max_depth=4).fit(X, y)
from sklearn.externals.six import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
dot_data = export_graphviz(tree, out_file=None,
filled=True, rounded=True,
special_characters=True,
precision=2)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
四、决策树总结
1、决策树的优缺点
2、决策树 VS 线性回归
3、决策树的演化
决策树演化发展成GBDT、XGBoost模型,在工业界有很广泛的应用。
更多GBDT的内容,可参考:
刘启林:GBDT的原理、公式推导、Python实现、可视化和应用zhuanlan.zhihu.com更多XGBoost的内容,可参考:
刘启林:XGBoost的原理、公式推导、Python实现和应用zhuanlan.zhihu.com结束语:
决策树的数学原理清晰,可解释强,运行快,在数据挖掘、推荐系统、金融风控、计算广告、智能营销等领域得到广泛应用,你值得掌握。
参考文献:
1、都有为等,物理学大词典[M],科学出版社,2017.12
2、Thomas M. Cover,信息论基础[M],机械工业出版社,2007.11
3、https://developers.google.com/machine-learning/glossary
4、李航, 统计学习方法(第2版)[M], 清华大学出版社, 2019.05