分类树的八个参数,一个属性,四个接口。
八个参数:
Criterion,
两个随机性相关的参数(random_state,splitter),
五个剪枝参数(max_depth,min_samples_split,min_samples_leaf,max_feature,min_impurity_decrease)
一个属性:feature_importances_
四个接口:fit,score,apply,predict
1.criterion
不填默认基尼系数,填写gini使用基尼系数,填写entropy使用信息增益
2.random_state
输入任意整数,会一直长出同一棵树,让模型稳定下来,可防止过拟合
3.splitter
有两种输入值,输入”best",决策树在分枝时虽然随机,但是还是会优先选择更重要的特征进行分枝(重要性可以通过属性feature_importances_查看),输入“random",决策树在分枝时会更加随机,树会因为含有更多的不必要信息而更深更大,并因这些不必要信息而降低对训练集的拟合。
4.max_deep
限制树的最大深度,超过设定深度的树枝全部剪掉这是用得最广泛的剪枝参数,在高维度低样本量时非常有效。决策树多生长一层,对样本量的需求会增加一倍,所以限制树深度能够有效地限制过拟合。
5.min_samples_leaf
一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生
6.min_samples_split
一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生
import graphviz
import pandas as pd
from sklearn import tree
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
# 得到红酒数据
wine = load_wine()
# 整合成一张表 dataframe格式 左右拼接
print(pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis = 1))
# 随机划分训练集和测试集 百分之70为训练数据 百分之30为测试数据分组
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
print(Xtrain.shape , Ytrain.shape) # ((124, 13), (124,))
print(Xtest.shape , Ytest.shape) # ((54, 13), (54,))
# 建立模型
clf = tree.DecisionTreeClassifier(criterion="entropy"
,random_state=30
,splitter="random"
,max_depth=3
,min_samples_leaf=5
,min_samples_split=15
)
# fit 接口一
clf = clf.fit(Xtrain, Ytrain)
# 训练集的拟合程度如何 若score_train 分数很高 score分数很低 则说明过拟合了
score = clf.score(Xtest, Ytest) # 返回预测的准确度 acc
score_train = clf.score(Xtrain, Ytrain)
score_cross = cross_val_score(clf,wine.data,wine.target,cv=10).mean() # 交叉验证
print('test:', score, 'train:', score_train, 'cross:', score_cross)
# predict返回每个测试样本的分类/回归结果
clf.predict(Xtest)
# 数据集中标签类型为英文 这里用中文重新定义
feature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜色强度','色调','od280/od315稀释葡萄酒','脯氨酸']
# 画树
dot_data = tree.export_graphviz(clf
,out_file = None
,feature_names = feature_name
,class_names= ['酒','雪莉','贝尔摩德'] # 重新定义类名
,filled = True # 不同类型填充颜色 不纯度越低颜色越深
,rounded = True # 方框形状
)
# 展示树
graph = graphviz.Source(dot_data)
print(graph)
# 特征的重要性(贡献率)重要属性
print(clf.feature_importances_)
print([*zip(feature_name,clf.feature_importances_)])
分类树参数列表