sklearn入门-决策树及其可视化

建立一棵树:

1.导入需要的算法库和模块

from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

2 探索数据

wine = load_wine()
wine.data

wine.target

 

 假如是一个表,他是什么样子

import pandas as pd
pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis = 1)

wine.target_names

 3 分成测试集和训练集

Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size = 0.3)

这里注意前面顺序,否则模型混乱

Xtrain.shape

 

Xtest.shape

4 训练模型和预测

clf = tree.DecisionTreeClassifier(criterion = "entropy",random_state = 10) #random_stat控制随机性,这样准确率就是固定的,随机模式是30
clf = clf.fit(Xtrain, Ytrain)
score = clf.score(Xtest,Ytest)
score

 

 5 可视化

前期环境准备 需要cmd执行以下命令行

conda install python-graphviz

添加环境变量:例如

C:\Users\linxid\Anaconda3\Library\bin\graphviz

 开始画树

feature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜色强度','色调','葡萄酒','脯氨酸']

import graphviz
dot_data = tree.export_graphviz(clf
                               ,feature_names = feature_name  
                               ,class_names = ['琴酒','雪梨','贝尔摩德']
                               ,filled = True  #颜色
                               ,rounded = True  #边框圆滑\
                            
                               )
graph = graphviz.Source(dot_data)
graph

 

 

 看哪些元素比较重要

[*zip(feature_name, clf.feature_importances_)]

 

在运行过程当中会发现运行的结果可能会不一样,这个划分数据集的方法有关。所以结果会不一样子

解决过拟合问题的方法,降维方法:

利用max_depth使用,用作树的“精修”。

max_features限制分枝时考虑率的特征个数,擦红果限制个数的特征都会被舍弃。和max_depth异曲同工。

min_impurity_decrease限制信息增益打的大小。

  • 确认最优的减枝
import matplotlib.pyplot as plt
test = []
for i in range(10):
    clf = tree.DecisionTreeClassifier(
                                   max_depth=i +1
                                  ,criterion = "entropy"
                                  ,random_state = 30 #random_stat控制随机性,这样准确率就是固定的,随机模式是30
                                  ,splitter="random" #控制随机性
                                 ) 
    clf = clf.fit(Xtrain, Ytrain)
    score = clf.score(Xtest,Ytest)
    test.append(score)
plt.plot(range(1,11), test, color = 'red', label = " max_depth")
plt.legend()
plt.show()

 所以取3的时候最好。

其实不用每次对一个这样的参数,有待更新

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值