参考资料
- 理论推导部分参考刘建平老师的博客《决策树算法原理》
- Python 实现以及参数含义参考《scikit-learn决策树算法类库使用小结》
算法实现
导入包
- 导入部分末尾对 `os.environ["PATH"]` 的设置会把 Graphviz 的 bin 目录加入 PATH 环境变量，以便后面用 pydotplus 生成决策树图片时能找到 Graphviz 程序（路径请按本机安装位置修改）。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
import sys
import os
import warnings
warnings.filterwarnings('ignore')
os.environ["PATH"] += os.pathsep + 'D:/Python/Graphviz/bin'
数据准备
- 使用 sklearn 包中自带的鸢尾花数据集。
# Load the Iris dataset that ships with scikit-learn:
# `X` is the feature matrix, `y` the class labels.
iris = load_iris()
X, y = iris.data, iris.target
# Bare expression so the notebook cell echoes the feature-matrix dimensions.
X.shape
模型建立与调参
from sklearn.model_selection import GridSearchCV

# Base estimator. A fixed random_state makes the 'random' splitter (and sklearn's
# per-split feature permutation) reproducible, so the reported best parameters and
# scores are stable from run to run.
clf = DecisionTreeClassifier(random_state=42)

# Hyper-parameter grid: 2 criteria x 2 splitters x 6 depths x 1 weighting x 4 leaf caps
# = 96 candidate settings.
# NOTE(review): a binary tree with at most 5 leaves is at most 4 levels deep, so with
# max_leaf_nodes <= 5 none of these max_depth values (all >= 5) can ever bind; every
# depth yields the same tree. Consider widening max_leaf_nodes or lowering max_depth.
param_grid = [{'criterion': ['gini', 'entropy'],
               'splitter': ['best', 'random'],
               'max_depth': np.arange(5, 21, 3),
               'class_weight': ['balanced'],
               'max_leaf_nodes': np.arange(2, 6, 1)}]

# 10-fold cross-validation. For single-label multiclass data, micro-averaged F1
# equals accuracy. Train scores are kept in cv_results_ for over-fitting checks.
grid_search = GridSearchCV(clf, param_grid, cv=10,
                           scoring='f1_micro',
                           return_train_score=True)
grid_search.fit(X, np.ravel(y))
模型最优参数
# Best hyper-parameter combination found by the grid search (echoed by the notebook cell).
grid_search.best_params_
输出:
{'class_weight': 'balanced',
'criterion': 'gini',
'max_depth': 11,
'max_leaf_nodes': 5,
'splitter': 'random'}
各参数下模型评估
# Print every candidate's mean cross-validated score (and its std across folds) next
# to its parameter setting. Fixed 3-decimal formatting keeps the columns aligned and
# distinguishes candidates whose scores differ only slightly.
cvres = grid_search.cv_results_
for mean_score, std_score, params in zip(cvres["mean_test_score"],
                                         cvres["std_test_score"],
                                         cvres["params"]):
    print("{:.3f} (+/-{:.3f})".format(mean_score, std_score), params)
输出:
0.67 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 2, 'splitter': 'best'}
0.66 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 2, 'splitter': 'random'}
0.95 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 3, 'splitter': 'best'}
0.79 {'class_weight': 'balanced', 'criterion': 'gini', 'max_depth': 5, 'max_leaf_nodes': 3, 'splitter': 'random'}
...
保存模型训练结果（将最优决策树的结构导出为 Graphviz DOT 文件）
# Best estimator, refit by GridSearchCV on the full data set.
final_model = grid_search.best_estimator_
# Export the tree's structure to a Graphviz DOT file. This is a text description for
# visualization, not a serialized model that can be reloaded later.
# Nodes are labelled with the real feature/class names instead of X[i] / indices,
# and UTF-8 is explicit so the file does not depend on the platform's default encoding.
with open("iris.dot", 'w', encoding='utf-8') as f:
    tree.export_graphviz(final_model, out_file=f,
                         feature_names=iris.feature_names,
                         class_names=iris.target_names)
决策树可视化
from IPython.display import Image
import pydotplus

# Render the best tree inline: build the DOT source in memory, parse it with
# pydotplus, then rasterise it to PNG (needs the Graphviz binaries on PATH).
export_options = dict(
    out_file=None,                     # return the DOT text instead of writing a file
    feature_names=iris.feature_names,  # readable split conditions
    class_names=iris.target_names,     # readable leaf labels
    filled=True,                       # paint nodes by majority class
    rounded=True,
    special_characters=True,
)
dot_data = tree.export_graphviz(final_model, **export_options)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
# To also keep a PDF copy of the tree, uncomment:
#graph.write_pdf("DTtree.pdf")
输出: