sklearn中决策树都在‘tree’这个模块中,这个模块总共包含五类:
- tree.DecisionTreeClassifier 分类树
- tree.DecisionTreeRegressor 回归树
- tree.export_graphviz 画图专用
- tree.ExtraTreeClassifier 高随机版本的分类树
- tree.ExtraTreeRegressor 高随机版本的回归树
我们先开始学习分类树 DecisionTreeClassifier 的参数
1.criterion
衡量决策树分支的指标——不纯度
Criterion是用来决定不纯度的计算方法的
sklearn提供了两种选择:
(1)“entropy”,使用信息熵(Entropy)
(2)“gini”,使用基尼系数(Gini Impurity)
基尼系数比信息熵计算更快,而且信息熵更敏感,更精细,容易过拟合,默认是基尼系数。
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
wine=load_wine()
print(wine)#数据类型如下所示
{'data': array([[1.423e+01, 1.710e+00, 2.430e+00, ..., 1.040e+00, 3.920e+00,
1.065e+03],
[1.320e+01, 1.780e+00, 2.140e+00, ..., 1.050e+00, 3.400e+00,
1.050e+03],
[1.316e+01, 2.360e+00, 2.670e+00, ..., 1.030e+00, 3.170e+00,
1.185e+03],
...,
[1.327e+01, 4.280e+00, 2.260e+00, ..., 5.900e-01, 1.560e+00,
8.350e+02],
[1.317e+01, 2.590e+00, 2.370e+00, ..., 6.000e-01, 1.620e+00,
8.400e+02],
[1.413e+01, 4.100e+00, 2.740e+00, ..., 6.100e-01, 1.600e+00,
5.600e+02]]), 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2]), 'frame': None, 'target_names': array(['class_0', 'class_1', 'class_2'], dtype='<U7'), 'DESCR': '.. _wine_dataset:\n\nWine recognition dataset\n------------------------\n\n**Data Set Characteristics:**\n\n :Number of Instances: 178 (50 in each of three classes)\n :Number of Attributes: 13 numeric, predictive attributes and the class\n :Attribute Information:\n \t\t- Alcohol\n \t\t- Malic acid\n \t\t- Ash\n\t\t- Alcalinity of ash \n \t\t- Magnesium\n\t\t- Total phenols\n \t\t- Flavanoids\n \t\t- Nonflavanoid phenols\n \t\t- Proanthocyanins\n\t\t- Color intensity\n \t\t- Hue\n \t\t- OD280/OD315 of diluted wines\n \t\t- Proline\n\n - class:\n - class_0\n - class_1\n - class_2\n\t\t\n :Summary Statistics:\n \n ============================= ==== ===== ======= =====\n Min Max Mean SD\n ============================= ==== ===== ======= =====\n Alcohol: 11.0 14.8 13.0 0.8\n Malic Acid: 0.74 5.80 2.34 1.12\n Ash: 1.36 3.23 2.36 0.27\n Alcalinity of Ash: 10.6 30.0 19.5 3.3\n Magnesium: 70.0 162.0 99.7 14.3\n Total Phenols: 0.98 3.88 2.29 0.63\n Flavanoids: 0.34 5.08 2.03 1.00\n Nonflavanoid Phenols: 0.13 0.66 0.36 0.12\n Proanthocyanins: 0.41 3.58 1.59 0.57\n Colour Intensity: 1.3 13.0 5.1 2.3\n Hue: 0.48 1.71 0.96 0.23\n OD280/OD315 of diluted wines: 1.27 4.00 2.61 0.71\n Proline: 278 1680 746 315\n ============================= ==== ===== ======= =====\n\n :Missing Attribute Values: None\n :Class Distribution: class_0 (59), class_1 (71), class_2 (48)\n :Creator: R.A. Fisher\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n :Date: July, 1988\n\nThis is a copy of UCI ML Wine recognition datasets.\nhttps://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data\n\nThe data is the results of a chemical analysis of wines grown in the same\nregion in Italy by three different cultivators. There are thirteen different\nmeasurements taken for different constituents found in the three types of\nwine.\n\nOriginal Owners: \n\nForina, M. et al, PARVUS - \nAn Extendible Package for Data Exploration, Classification and Correlation. \nInstitute of Pharmaceutical and Food Analysis and Technologies,\nVia Brigata Salerno, 16147 Genoa, Italy.\n\nCitation:\n\nLichman, M. (2013). UCI Machine Learning Repository\n[https://archive.ics.uci.edu/ml]. Irvine, CA: University of California,\nSchool of Information and Computer Science. \n\n.. topic:: References\n\n (1) S. Aeberhard, D. Coomans and O. de Vel, \n Comparison of Classifiers in High Dimensional Settings, \n Tech. Rep. no. 92-02, (1992), Dept. of Computer Science and Dept. of \n Mathematics and Statistics, James Cook University of North Queensland. \n (Also submitted to Technometrics). \n\n The data was used with many others for comparing various \n classifiers. The classes are separable, though only RDA \n has achieved 100% correct classification. \n (RDA : 100%, QDA 99.4%, LDA 98.9%, 1NN 96.1% (z-transformed data)) \n (All results using the leave-one-out technique) \n\n (2) S. Aeberhard, D. Coomans and O. de Vel, \n "THE CLASSIFICATION PERFORMANCE OF RDA" \n Tech. Rep. no. 92-01, (1992), Dept. of Computer Science and Dept. of \n Mathematics and Statistics, James Cook University of North Queensland. \n (Also submitted to Journal of Chemometrics).\n', 'feature_names': ['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']}
import pandas as pd
pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1)
0 1 2 3 4 5 ... 8 9 10 11 12 0
0 14.23 1.71 2.43 15.6 127.0 2.80 ... 2.29 5.64 1.04 3.92 1065.0 0
1 13.20 1.78 2.14 11.2 100.0 2.65 ... 1.28 4.38 1.05 3.40 1050.0 0
2 13.16 2.36 2.67 18.6 101.0 2.80 ... 2.81 5.68 1.03 3.17 1185.0 0
3 14.37 1.95 2.50 16.8 113.0 3.85 ... 2.18 7.80 0.86 3.45 1480.0 0
4 13.24 2.59 2.87 21.0 118.0 2.80 ... 1.82 4.32 1.04 2.93 735.0 0
.. ... ... ... ... ... ... ... ... ... ... ... ... ..
173 13.71 5.65 2.45 20.5 95.0 1.68 ... 1.06 7.70 0.64 1.74 740.0 2
174 13.40 3.91 2.48 23.0 102.0 1.80 ... 1.41 7.30 0.70 1.56 750.0 2
175 13.27 4.28 2.26 20.0 120.0 1.59 ... 1.35 10.20 0.59 1.56 835.0 2
176 13.17 2.59 2.37 20.0 120.0 1.65 ... 1.46 9.30 0.60 1.62 840.0 2
177 14.13 4.10 2.74 24.5 96.0 2.05 ... 1.35 9.20 0.61 1.60 560.0 2
[178 rows x 14 columns]
Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)
clf=tree.DecisionTreeClassifier(criterion="entropy")
clf=clf.fit(Xtrain,Ytrain)
score=clf.score(Xtest,Ytest)
score
0.8518518518518519
如果我们想把树画出来
import graphviz
feature_name=wine.feature_names
dot_data=tree.export_graphviz(clf,feature_names=feature_name,class_names=['琴酒','雪莉','贝尔摩德'],filled=True,rounded=True)
graph=graphviz.Source(dot_data)
clf.feature_importances_#属性的重要程度
2.random_state和splitter
两个都是控制决策树的随机选项,可以防止过拟合。
3.剪枝参数
score=clf.score(Xtrain,Ytrain)
print(score)
1.0
为了让决策树有更好的泛化,我们要对决策树进行剪枝。剪枝策略对决策树的影响巨大,正确的剪枝策略是优化决策树算法的核心。
- max_depth
限制树的最大深度,超过设定深度的树枝全部剪掉 。在高维度低样本量时非常有效。实际使用时,建议从=3开始尝试,查看拟合效果。
- min_samples_leaf 和 min_samples_split
min_samples_leaf ,分支会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生,建议从=5开始使用。
min_samples_split,一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生。
- min_impurity_decrease
min_impurity_decrease限制信息增益的大小,信息增益小于设定数值的分支不会发生