from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import pandas as pd
# Load the wine dataset bundled with scikit-learn
# (178 samples, 13 numeric features, 3 cultivar classes).
wine=load_wine()
# Quick inspection: feature-matrix shape and the raw label vector.
print(wine.data.shape,wine.target)
# Combine features and labels into one DataFrame for a tabular overview.
# Naming the columns avoids the duplicate "0" label that appears when both
# the first feature and the unnamed target column default to 0.
pd_wine=pd.concat(
    [pd.DataFrame(wine.data, columns=wine.feature_names),
     pd.DataFrame(wine.target, columns=["target"])],
    axis=1,
)
print(wine.feature_names,wine.target_names)
print(pd_wine.describe())
# Split into train / test sets (70% / 30%).
# stratify keeps the three class proportions equal in both splits;
# random_state makes the split reproducible across runs.
xtrain,xtest,ytrain,ytest=train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=42, stratify=wine.target
)
# Build the model.
# NOTE: with the default splitter="best" the tree evaluates EVERY feature at
# each split (no random feature sub-sampling — that would need
# splitter="random" or a RandomForest). random_state only fixes tie-breaking
# among equally good splits, so repeated runs build the same tree.
clf=tree.DecisionTreeClassifier(criterion="entropy",random_state=30)
clf=clf.fit(xtrain,ytrain)  # train on the 70% split
score=clf.score(xtest,ytest)  # mean accuracy on the held-out 30% split
print("准确度:",score)
# Visualize the fitted tree with Graphviz.
import graphviz

# Chinese display names for the 13 wine features, in dataset column order.
cn_features = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮',
               '非黄烷类酚类', '花青素', '颜色强度', '色调',
               'od280/od315稀释葡萄酒', '脯氨酸']

# Export the tree as DOT source, with nodes filled and rounded for readability.
dot_source = tree.export_graphviz(
    clf,
    feature_names=cn_features,
    class_names=["琴酒", "雪莉", "贝尔摩德"],
    filled=True,
    rounded=True,
)
graph = graphviz.Source(dot_source)
# Printing a Source object emits the raw DOT text.
print(graph)
输出结果:
(178, 13) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline'] ['class_0' 'class_1' 'class_2']
0 1 2 ... 11 12 0
count 178.000000 178.000000 178.000000 ... 178.000000 178.000000 178.000000
mean 13.000618 2.336348 2.366517 ... 2.611685 746.893258 0.938202
std 0.811827 1.117146 0.274344 ... 0.709990 314.907474 0.775035
min 11.030000 0.740000 1.360000 ... 1.270000 278.000000 0.000000
25% 12.362500 1.602500 2.210000 ... 1.937500 500.500000 0.000000
50% 13.050000 1.865000 2.360000 ... 2.780000 673.500000 1.000000
75% 13.677500 3.082500 2.557500 ... 3.170000 985.000000 2.000000
max 14.830000 5.800000 3.230000 ... 4.000000 1680.000000 2.000000
[8 rows x 14 columns]
准确度: 0.9444444444444444
digraph Tree {
node [shape=box, style="filled, rounded", color="black", fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="类黄酮 <= 2.3\nentropy = 1.56\nsamples = 124\nvalue = [44, 49, 31]\nclass = 雪莉", fillcolor="#f3fdf7"] ;
1 [label="颜色强度 <= 3.825\nentropy = 0.993\nsamples = 69\nvalue = [0, 38, 31]\nclass = 雪莉", fillcolor="#dbfae8"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="entropy = 0.0\nsamples = 35\nvalue = [0, 35, 0]\nclass = 雪莉", fillcolor="#39e581"] ;
1 -> 2 ;
3 [label="类黄酮 <= 1.4\nentropy = 0.431\nsamples = 34\nvalue = [0, 3, 31]\nclass = 贝尔摩德", fillcolor="#8d4ce8"] ;
1 -> 3 ;
4 [label="entropy = 0.0\nsamples = 31\nvalue = [0, 0, 31]\nclass = 贝尔摩德", fillcolor="#8139e5"] ;
3 -> 4 ;
5 [label="entropy = 0.0\nsamples = 3\nvalue = [0, 3, 0]\nclass = 雪莉", fillcolor="#39e581"] ;
3 -> 5 ;
6 [label="脯氨酸 <= 679.0\nentropy = 0.722\nsamples = 55\nvalue = [44, 11, 0]\nclass = 琴酒", fillcolor="#eca06a"] ;
0 -> 6 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
7 [label="entropy = 0.0\nsamples = 10\nvalue = [0, 10, 0]\nclass = 雪莉", fillcolor="#39e581"] ;
6 -> 7 ;
8 [label="颜色强度 <= 3.435\nentropy = 0.154\nsamples = 45\nvalue = [44, 1, 0]\nclass = 琴酒", fillcolor="#e6843d"] ;
6 -> 8 ;
9 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = 雪莉", fillcolor="#39e581"] ;
8 -> 9 ;
10 [label="entropy = 0.0\nsamples = 44\nvalue = [44, 0, 0]\nclass = 琴酒", fillcolor="#e58139"] ;
8 -> 10 ;
}
进程已结束,退出代码为 0