决策树(python实现)

参考资料

算法实现

导入包

  • 最后一行用于设置graphviz插件的环境变量
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
import sys
import os
import warnings
warnings.filterwarnings('ignore')
os.environ["PATH"] += os.pathsep + 'D:/Python/Graphviz/bin'

数据准备

  • 使用sklearn包中自带鸢尾花数据集。
# 载入鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target
X.shape

模型建立与调参

clf = DecisionTreeClassifier()
from sklearn.model_selection import GridSearchCV
param_grid = [{'criterion':['gini','entropy'],
              'splitter':['best', 'random'],
              'max_depth':np.arange(5,21,3),
              'class_weight':['balanced'],
              'max_leaf_nodes':np.arange(2,6,1)}]
grid_search = GridSearchCV(clf, param_grid, cv = 10,
                          scoring = 'f1_micro',
                          return_train_score = True)
grid_search.fit(X,np.ravel(y))

模型最优参

grid_search.best_params_

输出:

{'class_weight': 'balanced',
 'criterion': 'gini',
 'max_depth': 11,
 'max_leaf_nodes': 5,
 'splitter': 'random'}

各参数下模型评估

cvres = grid_search.cv_results_
for accuracy,params in zip(cvres["mean_test_score"],cvres["params"]):
    print("{:.2}".format(accuracy),params)

输出:

0.67 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 2, 'splitter': 'best'}
0.66 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 2, 'splitter': 'random'}
0.95 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 3, 'splitter': 'best'}
0.79 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 3, 'splitter': 'random'}
...

保存模型训练结果

final_model = grid_search.best_estimator_
# 保存模型
with open("iris.dot", 'w') as f:
    f = tree.export_graphviz(final_model, out_file=f)

决策树可视化

import pydotplus 
from IPython.display import Image  
dot_data = tree.export_graphviz(final_model, out_file=None, 
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True)  
graph = pydotplus.graph_from_dot_data(dot_data)  
Image(graph.create_png()) 
#保存为pdf文件
#graph.write_pdf("DTtree.pdf") 

输出:
请添加图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

羽星_s

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值