可视化决策树之Python实现

决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。一些基础原理这里就不再一一介绍了,直接进入今天的主题,如何可视化决策树。

本篇使用klearn来实现决策树的过程,下面是详细讲解:

首先导入必要的包:

 

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score

然后,导入数据集。我用的是kaggle上的蘑菇数据集,这是一个经典的决策树数据集,非常适合决策树,下面我们就会知道。

 

 

data = pd.read_csv("mushrooms.csv")
data.head()

先初步认识一下数据集:

 

可以看出这是一个分类变量的数据集。然后,我们就要将它变成数值变量,好利于下面的建模。

 

from sklearn.preprocessing import LabelEncoder
labelencoder = LabelEncoder()
for col in data.columns:
    data[col] = labelencoder.fit_transform(data[col])
data.head()


之后,我们来看看数据的大小:

 

 

data.shape

(8124, 23)
数据准备后,我们开始提取训练集与测试集。

 

 

y = data['class']
X = data.drop('class', axis=1)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.8)
columns = X_train.columns

接着标准化训练集

 

 

# 数据标准化
from sklearn.preprocessing import StandardScaler
ss_X = StandardScaler()
ss_y = StandardScaler()
X_train = ss_X.fit_transform(X_train)
X_test = ss_X.transform(X_test)

接着,构建决策树模型

 

 

from sklearn.tree import DecisionTreeClassifier
model_tree = DecisionTreeClassifier()
model_tree.fit(X_train, y_train)

评价模型准确性

 

 

y_prob = model_tree.predict_proba(X_test)[:,1]
y_pred = np.where(y_prob > 0.5, 1, 0)
model_tree.score(X_test, y_pred)

可以得到结果:1. 

 

说明决策树非常吻合此数据集。

最后,完成决策树的可视化

 

# 可视化树图
data_ = pd.read_csv("mushrooms.csv")
data_feature_name = data_.columns[1:]
data_target_name = np.unique(data_["class"])
import graphviz
import pydotplus
from sklearn import tree
from IPython.display import Image
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
dot_tree = tree.export_graphviz(model_tree,out_file=None,feature_names=data_feature_name,class_names=data_target_name,filled=True, rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_tree)
img = Image(graph.create_png())
graph.write_png("out.png")

 

 

 

 

 

 

 

 

 

注意:graphviz包不仅需要使用pip install graphviz安装还需要单独安装。使用时,还需要引入graphviz绝对路径。

参考:http://scikit-learn.org/stable/modules/tree.html

graphviz-2.38.msi安装包下载:http://www.graphviz.org/Download_windows.php

数据集:http://download.csdn.net/download/llh_1178/10115766
 



 

  • 25
    点赞
  • 291
    收藏
    觉得还不错? 一键收藏
  • 37
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 37
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值