数据科学 案例7 决策树之电脑购买(代码)
9 决策树
import os
import pandas as pd
1、导入数据
data = pd.read_csv(r'./data/AllElectronics.csv',encoding='gbk', skipinitialspace=True) #, skipinitialspace=True不能少
data.head()
age | income | student | credit_rating | buys_computer | |
---|---|---|---|---|---|
0 | 1 | 3 | 0 | 1 | 0 |
1 | 1 | 3 | 0 | 2 | 0 |
2 | 2 | 3 | 0 | 1 | 1 |
3 | 3 | 2 | 0 | 1 | 1 |
4 | 3 | 1 | 1 | 1 | 1 |
target = data['buys_computer']
data = data.loc[:, 'age':'credit_rating']
data.head()
age | income | student | credit_rating | |
---|---|---|---|---|
0 | 1 | 3 | 0 | 1 |
1 | 1 | 3 | 0 | 2 |
2 | 2 | 3 | 0 | 1 |
3 | 3 | 2 | 0 | 1 |
4 | 3 | 1 | 1 | 1 |
2、CART算法(分类树)
2.1 建立CART模型
import sklearn.tree as tree
clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=5, min_samples_split=2, min_samples_leaf=1, random_state=12345) # 当前支持计算信息增益和GINI
clf.fit(data, target)
DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=5,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False,
random_state=12345, splitter='best')
2.2 可视化
可以使用graphviz将树结构输出,在python中嵌入graphviz可参考:pygraphviz
Python决策树可视化:GraphViz’s executables not found的解决方法
- 可视化
使用dot文件进行决策树可视化需要安装一些工具:
- 第一步是安装graphviz。linux可以用apt-get或者yum的方法安装。如果是windows,就在官网下载msi文件安装。
无论是linux还是windows,装完后都要设置环境变量,将graphviz的bin目录加到PATH,
比如windows,将C:/Program Files (x86)/Graphviz2.38/bin/加入了PATH - 第二步是安装python插件graphviz: pip install graphviz
- 第三步是安装python插件pydotplus: pip install pydotplus
# get_ipython().magic('matplotlib inline')
tree.export_graphviz(clf, out_file='cart.dot')
import pydotplus
from IPython.display import Image
import sklearn.tree as tree
#此前设置的环境变量不好用,可以用以下方法,查看了环境变量,发现没有就会追加上。
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/' #注意修改你的路径
# In[18]:
dot_data = tree.export_graphviz(
clf,
out_file=None,
feature_names=data.columns,
max_depth=5,
class_names=['0','1'],
filled=True
)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())