import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import time
决策树(红酒数据集为例)
1:Criterion
- 将数据表格转化为一棵树,需要找到最佳节点和最佳分支方法,对分类树来说,衡量这个最佳的指标叫做不纯度
- Criterion这个参数正是用来决定不纯度的计算方法的
- entropy 信息熵
- gini 基尼系数
- 区别 :比起基尼系数,信息熵对不纯度更加敏感,对不纯度的惩罚最强。但是在实际使用中,信息熵和基尼系数的效果基 本相同。信息熵的计算比基尼系数缓慢一些,因为基尼系数的计算不涉及对数
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
wine = load_wine()
# The dataset ships with `data` (features) and `target` (labels) as separate
# arrays; join them column-wise with pd.concat to inspect the full table.
features = pd.DataFrame(wine.data)
labels = pd.DataFrame(wine.target)
pd.concat([features, labels], axis=1)
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 0 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 14.23 | 1.71 | 2.43 | 15.6 | 127.0 | 2.80 | 3.06 | 0.28 | 2.29 | 5.640000 | 1.04 | 3.92 | 1065.0 | 0 |
1 | 13.20 | 1.78 | 2.14 | 11.2 | 100.0 | 2.65 | 2.76 | 0.26 | 1.28 | 4.380000 | 1.05 | 3.40 | 1050.0 | 0 |
2 | 13.16 | 2.36 | 2.67 | 18.6 | 101.0 | 2.80 | 3.24 | 0.30 | 2.81 | 5.680000 | 1.03 | 3.17 | 1185.0 | 0 |
3 | 14.37 | 1.95 | 2.50 | 16.8 | 113.0 | 3.85 | 3.49 | 0.24 | 2.18 | 7.800000 | 0.86 | 3.45 | 1480.0 | 0 |
4 | 13.24 | 2.59 | 2.87 | 21.0 | 118.0 | 2.80 | 2.69 | 0.39 | 1.82 | 4.320000 | 1.04 | 2.93 | 735.0 | 0 |
5 | 14.20 | 1.76 | 2.45 | 15.2 | 112.0 | 3.27 | 3.39 | 0.34 | 1.97 | 6.750000 | 1.05 | 2.85 | 1450.0 | 0 |
6 | 14.39 | 1.87 | 2.45 | 14.6 | 96.0 | 2.50 | 2.52 | 0.30 | 1.98 | 5.250000 | 1.02 | 3.58 | 1290.0 | 0 |
7 | 14.06 | 2.15 | 2.61 | 17.6 | 121.0 | 2.60 | 2.51 | 0.31 | 1.25 | 5.050000 | 1.06 | 3.58 | 1295.0 | 0 |
8 | 14.83 | 1.64 | 2.17 | 14.0 | 97.0 | 2.80 | 2.98 | 0.29 | 1.98 | 5.200000 | 1.08 | 2.85 | 1045.0 | 0 |
9 | 13.86 | 1.35 | 2.27 | 16.0 | 98.0 | 2.98 | 3.15 | 0.22 | 1.85 | 7.220000 | 1.01 | 3.55 | 1045.0 | 0 |
10 | 14.10 | 2.16 | 2.30 | 18.0 | 105.0 | 2.95 | 3.32 | 0.22 | 2.38 | 5.750000 | 1.25 | 3.17 | 1510.0 | 0 |
11 | 14.12 | 1.48 | 2.32 | 16.8 | 95.0 | 2.20 | 2.43 | 0.26 | 1.57 | 5.000000 | 1.17 | 2.82 | 1280.0 | 0 |
12 | 13.75 | 1.73 | 2.41 | 16.0 | 89.0 | 2.60 | 2.76 | 0.29 | 1.81 | 5.600000 | 1.15 | 2.90 | 1320.0 | 0 |
13 | 14.75 | 1.73 | 2.39 | 11.4 | 91.0 | 3.10 | 3.69 | 0.43 | 2.81 | 5.400000 | 1.25 | 2.73 | 1150.0 | 0 |
14 | 14.38 | 1.87 | 2.38 | 12.0 | 102.0 | 3.30 | 3.64 | 0.29 | 2.96 | 7.500000 | 1.20 | 3.00 | 1547.0 | 0 |
15 | 13.63 | 1.81 | 2.70 | 17.2 | 112.0 | 2.85 | 2.91 | 0.30 | 1.46 | 7.300000 | 1.28 | 2.88 | 1310.0 | 0 |
16 | 14.30 | 1.92 | 2.72 | 20.0 | 120.0 | 2.80 | 3.14 | 0.33 | 1.97 | 6.200000 | 1.07 | 2.65 | 1280.0 | 0 |
17 | 13.83 | 1.57 | 2.62 | 20.0 | 115.0 | 2.95 | 3.40 | 0.40 | 1.72 | 6.600000 | 1.13 | 2.57 | 1130.0 | 0 |
18 | 14.19 | 1.59 | 2.48 | 16.5 | 108.0 | 3.30 | 3.93 | 0.32 | 1.86 | 8.700000 | 1.23 | 2.82 | 1680.0 | 0 |
19 | 13.64 | 3.10 | 2.56 | 15.2 | 116.0 | 2.70 | 3.03 | 0.17 | 1.66 | 5.100000 | 0.96 | 3.36 | 845.0 | 0 |
20 | 14.06 | 1.63 | 2.28 | 16.0 | 126.0 | 3.00 | 3.17 | 0.24 | 2.10 | 5.650000 | 1.09 | 3.71 | 780.0 | 0 |
21 | 12.93 | 3.80 | 2.65 | 18.6 | 102.0 | 2.41 | 2.41 | 0.25 | 1.98 | 4.500000 | 1.03 | 3.52 | 770.0 | 0 |
22 | 13.71 | 1.86 | 2.36 | 16.6 | 101.0 | 2.61 | 2.88 | 0.27 | 1.69 | 3.800000 | 1.11 | 4.00 | 1035.0 | 0 |
23 | 12.85 | 1.60 | 2.52 | 17.8 | 95.0 | 2.48 | 2.37 | 0.26 | 1.46 | 3.930000 | 1.09 | 3.63 | 1015.0 | 0 |
24 | 13.50 | 1.81 | 2.61 | 20.0 | 96.0 | 2.53 | 2.61 | 0.28 | 1.66 | 3.520000 | 1.12 | 3.82 | 845.0 | 0 |
25 | 13.05 | 2.05 | 3.22 | 25.0 | 124.0 | 2.63 | 2.68 | 0.47 | 1.92 | 3.580000 | 1.13 | 3.20 | 830.0 | 0 |
26 | 13.39 | 1.77 | 2.62 | 16.1 | 93.0 | 2.85 | 2.94 | 0.34 | 1.45 | 4.800000 | 0.92 | 3.22 | 1195.0 | 0 |
27 | 13.30 | 1.72 | 2.14 | 17.0 | 94.0 | 2.40 | 2.19 | 0.27 | 1.35 | 3.950000 | 1.02 | 2.77 | 1285.0 | 0 |
28 | 13.87 | 1.90 | 2.80 | 19.4 | 107.0 | 2.95 | 2.97 | 0.37 | 1.76 | 4.500000 | 1.25 | 3.40 | 915.0 | 0 |
29 | 14.02 | 1.68 | 2.21 | 16.0 | 96.0 | 2.65 | 2.33 | 0.26 | 1.98 | 4.700000 | 1.04 | 3.59 | 1035.0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
148 | 13.32 | 3.24 | 2.38 | 21.5 | 92.0 | 1.93 | 0.76 | 0.45 | 1.25 | 8.420000 | 0.55 | 1.62 | 650.0 | 2 |
149 | 13.08 | 3.90 | 2.36 | 21.5 | 113.0 | 1.41 | 1.39 | 0.34 | 1.14 | 9.400000 | 0.57 | 1.33 | 550.0 | 2 |
150 | 13.50 | 3.12 | 2.62 | 24.0 | 123.0 | 1.40 | 1.57 | 0.22 | 1.25 | 8.600000 | 0.59 | 1.30 | 500.0 | 2 |
151 | 12.79 | 2.67 | 2.48 | 22.0 | 112.0 | 1.48 | 1.36 | 0.24 | 1.26 | 10.800000 | 0.48 | 1.47 | 480.0 | 2 |
152 | 13.11 | 1.90 | 2.75 | 25.5 | 116.0 | 2.20 | 1.28 | 0.26 | 1.56 | 7.100000 | 0.61 | 1.33 | 425.0 | 2 |
153 | 13.23 | 3.30 | 2.28 | 18.5 | 98.0 | 1.80 | 0.83 | 0.61 | 1.87 | 10.520000 | 0.56 | 1.51 | 675.0 | 2 |
154 | 12.58 | 1.29 | 2.10 | 20.0 | 103.0 | 1.48 | 0.58 | 0.53 | 1.40 | 7.600000 | 0.58 | 1.55 | 640.0 | 2 |
155 | 13.17 | 5.19 | 2.32 | 22.0 | 93.0 | 1.74 | 0.63 | 0.61 | 1.55 | 7.900000 | 0.60 | 1.48 | 725.0 | 2 |
156 | 13.84 | 4.12 | 2.38 | 19.5 | 89.0 | 1.80 | 0.83 | 0.48 | 1.56 | 9.010000 | 0.57 | 1.64 | 480.0 | 2 |
157 | 12.45 | 3.03 | 2.64 | 27.0 | 97.0 | 1.90 | 0.58 | 0.63 | 1.14 | 7.500000 | 0.67 | 1.73 | 880.0 | 2 |
158 | 14.34 | 1.68 | 2.70 | 25.0 | 98.0 | 2.80 | 1.31 | 0.53 | 2.70 | 13.000000 | 0.57 | 1.96 | 660.0 | 2 |
159 | 13.48 | 1.67 | 2.64 | 22.5 | 89.0 | 2.60 | 1.10 | 0.52 | 2.29 | 11.750000 | 0.57 | 1.78 | 620.0 | 2 |
160 | 12.36 | 3.83 | 2.38 | 21.0 | 88.0 | 2.30 | 0.92 | 0.50 | 1.04 | 7.650000 | 0.56 | 1.58 | 520.0 | 2 |
161 | 13.69 | 3.26 | 2.54 | 20.0 | 107.0 | 1.83 | 0.56 | 0.50 | 0.80 | 5.880000 | 0.96 | 1.82 | 680.0 | 2 |
162 | 12.85 | 3.27 | 2.58 | 22.0 | 106.0 | 1.65 | 0.60 | 0.60 | 0.96 | 5.580000 | 0.87 | 2.11 | 570.0 | 2 |
163 | 12.96 | 3.45 | 2.35 | 18.5 | 106.0 | 1.39 | 0.70 | 0.40 | 0.94 | 5.280000 | 0.68 | 1.75 | 675.0 | 2 |
164 | 13.78 | 2.76 | 2.30 | 22.0 | 90.0 | 1.35 | 0.68 | 0.41 | 1.03 | 9.580000 | 0.70 | 1.68 | 615.0 | 2 |
165 | 13.73 | 4.36 | 2.26 | 22.5 | 88.0 | 1.28 | 0.47 | 0.52 | 1.15 | 6.620000 | 0.78 | 1.75 | 520.0 | 2 |
166 | 13.45 | 3.70 | 2.60 | 23.0 | 111.0 | 1.70 | 0.92 | 0.43 | 1.46 | 10.680000 | 0.85 | 1.56 | 695.0 | 2 |
167 | 12.82 | 3.37 | 2.30 | 19.5 | 88.0 | 1.48 | 0.66 | 0.40 | 0.97 | 10.260000 | 0.72 | 1.75 | 685.0 | 2 |
168 | 13.58 | 2.58 | 2.69 | 24.5 | 105.0 | 1.55 | 0.84 | 0.39 | 1.54 | 8.660000 | 0.74 | 1.80 | 750.0 | 2 |
169 | 13.40 | 4.60 | 2.86 | 25.0 | 112.0 | 1.98 | 0.96 | 0.27 | 1.11 | 8.500000 | 0.67 | 1.92 | 630.0 | 2 |
170 | 12.20 | 3.03 | 2.32 | 19.0 | 96.0 | 1.25 | 0.49 | 0.40 | 0.73 | 5.500000 | 0.66 | 1.83 | 510.0 | 2 |
171 | 12.77 | 2.39 | 2.28 | 19.5 | 86.0 | 1.39 | 0.51 | 0.48 | 0.64 | 9.899999 | 0.57 | 1.63 | 470.0 | 2 |
172 | 14.16 | 2.51 | 2.48 | 20.0 | 91.0 | 1.68 | 0.70 | 0.44 | 1.24 | 9.700000 | 0.62 | 1.71 | 660.0 | 2 |
173 | 13.71 | 5.65 | 2.45 | 20.5 | 95.0 | 1.68 | 0.61 | 0.52 | 1.06 | 7.700000 | 0.64 | 1.74 | 740.0 | 2 |
174 | 13.40 | 3.91 | 2.48 | 23.0 | 102.0 | 1.80 | 0.75 | 0.43 | 1.41 | 7.300000 | 0.70 | 1.56 | 750.0 | 2 |
175 | 13.27 | 4.28 | 2.26 | 20.0 | 120.0 | 1.59 | 0.69 | 0.43 | 1.35 | 10.200000 | 0.59 | 1.56 | 835.0 | 2 |
176 | 13.17 | 2.59 | 2.37 | 20.0 | 120.0 | 1.65 | 0.68 | 0.53 | 1.46 | 9.300000 | 0.60 | 1.62 | 840.0 | 2 |
177 | 14.13 | 4.10 | 2.74 | 24.5 | 96.0 | 2.05 | 0.76 | 0.56 | 1.35 | 9.200000 | 0.61 | 1.60 | 560.0 | 2 |
178 rows × 14 columns
# Random 70/30 train/test split. NOTE(review): no random_state is passed, so the
# split — and every score computed below — changes between runs; consider
# fixing random_state for reproducibility.
xtrain,xtest,ytrain,ytest=train_test_split(wine.data,wine.target,test_size=0.3)
xtrain.shape
(124, 13)
1.1.1 criterion=“entropy”
# Entropy-based impurity: train a decision tree and measure held-out accuracy.
clf = tree.DecisionTreeClassifier(criterion="entropy")
clf.fit(xtrain, ytrain)  # fit() returns self, so reassigning clf is unnecessary
score = clf.score(xtest, ytest)  # mean accuracy on the test split
score
0.9259259259259259
1.1.2 criterion=“gini”
# 不填默认gini
# Gini impurity (the default criterion): same workflow as the entropy tree above.
clf = tree.DecisionTreeClassifier(criterion="gini")
clf.fit(xtrain, ytrain)  # fit() returns self
score = clf.score(xtest, ytest)  # mean accuracy on the test split
score
0.9259259259259259
1.2 画树
- 使用 tree.export_graphviz
# 列出列项向量名字,便于查看理解结果
import graphviz
# Human-readable (Chinese) feature names so the rendered tree is interpretable.
feature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜 色强度','色调','od280/od315稀释葡萄酒','脯氨酸']
dot_data = tree.export_graphviz(clf
                                ,out_file=None
                                ,feature_names=feature_name  # FIX: list was defined but never passed, so nodes showed X[i] instead of names
                                ,class_names=["琴酒","雪莉","贝尔摩德"]
                                ,filled=True   # color nodes by majority class
                                ,rounded=True  # rounded node corners
                                )
graph = graphviz.Source(dot_data)
graph
1.3 探索决策树
# Per-feature importance of the fitted tree: values sum to 1, larger means the
# feature contributed more impurity reduction; 0 means the feature was unused.
clf.feature_importances_
array([0.41133413, 0. , 0. , 0. , 0. ,
0. , 0.38205108, 0.02401924, 0. , 0. ,
0. , 0.05485876, 0.12773679])
[* zip(feature_name,clf.feature_importances_)]
[('酒精', 0.4113341349496296),
('苹果酸', 0.0),
('灰', 0.0),
('灰的碱性', 0.0),
('镁', 0.0),
('总酚', 0.0),
('类黄酮', 0.3820510756901232),
('非黄烷类酚类', 0.024019241220117216),
('花青素', 0.0),
('颜 色强度', 0.0),
('色调', 0.0),
('od280/od315稀释葡萄酒', 0.05485876081137878),
('脯氨酸', 0.12773678732875127)]
1.4 random_state & splitter
- random_state用来设置分枝中的随机模式的参数,默认None,在高维度时随机性会表现更明显,低维度的数据 (比如鸢尾花数据集),随机性几乎不会显现。输入任意整数,会一直长出同一棵树,让模型稳定下来。
- splitter也是用来控制决策树中的随机选项的,有两种输入值,输入”best",决策树在分枝时虽然随机,但是还是会 优先选择更重要的特征进行分枝(重要性可以通过属性feature_importances_查看),输入“random",决策树在 分枝时会更加随机,树会因为含有更多的不必要信息而更深更大,并因这些不必要信息而降低对训练集的拟合。这 也是防止过拟合的一种方式。当你预测到你的模型会过拟合,用这两个参数来帮助你降低树建成之后过拟合的可能 性。当然,树一旦建成,我们依然是使用剪枝参数来防止过拟合。
# Fixing random_state pins the randomized parts of tree building,
# so the same tree is grown (and the same score returned) on every run.
clf = tree.DecisionTreeClassifier(criterion="entropy", random_state=30)
clf.fit(xtrain, ytrain)  # fit() returns self
score = clf.score(xtest, ytest)  # accuracy on the held-out split
score
0.9444444444444444
# Render the random_state=30 tree.
dot_data = tree.export_graphviz(clf
                                ,out_file=None
                                ,feature_names=feature_name  # FIX: pass the names defined earlier so nodes are labeled
                                ,class_names=["琴酒","雪莉","贝尔摩德"]
                                ,filled=True   # color nodes by majority class
                                ,rounded=True  # rounded node corners
                                )
graph = graphviz.Source(dot_data)
graph
1.5剪枝
- max_depth
- 限制树的最大深度,超过设定深度的树枝全部剪掉
- min_samples_leaf
- 一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分 枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生
- min_samples_split
- min_samples_split限定,一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则 分枝就不会发生。
# Pre-pruning: cap the depth and require a minimum number of samples
# per leaf / per split, all of which curb overfitting.
pruning_params = dict(
    criterion="gini",
    random_state=32,
    max_depth=3,
    min_samples_leaf=10,
    min_samples_split=10,
)
clf = tree.DecisionTreeClassifier(**pruning_params)
clf.fit(xtrain, ytrain)  # fit() returns self
score = clf.score(xtest, ytest)  # accuracy on the held-out split
score
0.8333333333333334
# Render the pruned tree.
dot_data = tree.export_graphviz(clf
                                ,out_file=None
                                ,feature_names=feature_name  # FIX: pass the names defined earlier so nodes are labeled
                                ,class_names=["琴酒","雪莉","贝尔摩德"]
                                ,filled=True   # color nodes by majority class
                                ,rounded=True  # rounded node corners
                                )
graph = graphviz.Source(dot_data)
graph
clf.score(xtrain,ytrain)
0.9435483870967742
clf.score(xtest,ytest)
0.8333333333333334
1.6 使用网格搜索调整参数
# Hyper-parameter grid: GridSearchCV evaluates every combination with 10-fold CV
# on the training set and keeps the best-scoring one.
parameters = {
    "splitter": ("best", "random"),
    "criterion": ("gini", "entropy"),
    "max_depth": list(range(2, 6)),
    "min_samples_leaf": list(range(1, 50, 5)),
    "min_impurity_decrease": list(np.linspace(0, 0.5, 20)),
}
clf = DecisionTreeClassifier(random_state=25)
gs = GridSearchCV(clf, parameters, cv=10)
gs.fit(xtrain, ytrain)
GridSearchCV(cv=10, error_score=nan,
estimator=DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None,
criterion='gini', max_depth=None,
max_features=None,
max_leaf_nodes=None,
min_impurity_decrease=0.0,
min_impurity_split=None,
min_samples_leaf=1,
min_samples_split=2,
min_weight_fraction_leaf=0.0,
presort='deprecated',
random_state=25,
splitter='best'),
iid='dep...
0.23684210526315788,
0.2631578947368421,
0.2894736842105263,
0.3157894736842105,
0.3421052631578947,
0.3684210526315789,
0.39473684210526316,
0.42105263157894735,
0.4473684210526315,
0.47368421052631576, 0.5],
'min_samples_leaf': [1, 6, 11, 16, 21, 26, 31, 36, 41,
46],
'splitter': ('best', 'random')},
pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
scoring=None, verbose=0)
gs.best_params_
{'criterion': 'gini',
'max_depth': 4,
'min_impurity_decrease': 0.02631578947368421,
'min_samples_leaf': 1,
'splitter': 'best'}
# Rebuild a single tree with the best hyper-parameters reported by the grid search.
best_params = dict(
    criterion="gini",
    max_depth=4,
    min_impurity_decrease=0.02631578947368421,
    min_samples_leaf=1,
    splitter="best",
    random_state=25,
)
clf = tree.DecisionTreeClassifier(**best_params)
clf.fit(xtrain, ytrain)  # fit() returns self
score = clf.score(xtest, ytest)  # accuracy on the held-out split
score
0.9444444444444444
clf.score(xtest,ytest)
0.9444444444444444
# Render the tuned tree.
dot_data = tree.export_graphviz(clf
                                ,out_file=None
                                ,feature_names=feature_name  # FIX: pass the names defined earlier so nodes are labeled
                                ,class_names=["琴酒","雪莉","贝尔摩德"]
                                ,filled=True   # color nodes by majority class
                                ,rounded=True  # rounded node corners
                                )
graph = graphviz.Source(dot_data)
graph
随机森林(RandomForestClassifier)
- 随机森林是非常具有代表性的Bagging集成算法,它的所有基评估器都是决策树,分类树组成的森林就叫做随机森 林分类器,回归树所集成的森林就叫做随机森林回归器。
2.1 n_estimators默认为100
- 这是森林中树木的数量,即基评估器的数量。这个参数对随机森林模型的精确性影响是单调的,n_estimators越 大,模型的效果往往越好。但是相应的,任何模型都有决策边界,n_estimators达到一定的程度之后,随机森林的 精确性往往不再上升或开始波动
wine = load_wine()
rfc = RandomForestClassifier(n_estimators=25)
# BUG FIX: the original did `rfc = cross_val_score(...).mean()`, which
# (a) overwrote the classifier with a single float and (b) then plotted that
# scalar against range(1, 11) — 10 x-values vs 1 y-value, which raises a
# ValueError in matplotlib. Keep the 10 per-fold scores and plot those;
# report the mean separately.
rfc_scores = cross_val_score(rfc, wine.data, wine.target, cv=10)  # one accuracy per fold
plt.plot(range(1, 11), rfc_scores)  # one point per CV fold
rfc_scores.mean()  # overall cross-validated accuracy
[<matplotlib.lines.Line2D at 0x11c2f2c3a48>]
0.9833333333333334
rfc
0.9833333333333334
- 菜菜的sklearn学习得到 https://live.bilibili.com/12582510