用Jupyter——鸢尾花的分类

 

1.加载鸢尾花数据集

[1]:

from  sklearn.datasets import load_iris

[2]:

# Load the bundled Iris dataset; its attributes are listed in the next cell.
iris = load_iris()

[3]:

# List the attributes of the loaded dataset object (displayed as the cell output).
dir(iris)

[3]:

['DESCR',
 'data',
 'data_module',
 'feature_names',
 'filename',
 'frame',
 'target',
 'target_names']

[4]:

# Feature matrix and class labels.
x, y = iris.data, iris.target

[5]:

# Print the feature matrix shape as (rows, columns).
print(x.shape)
(150, 4)

[6]:

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

2.按照7:3切分训练集和测试集,种子设置为888

[7]:

# 70/30 train/test split; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=888
)

[8]:

# Sanity-check the sizes of the two feature splits.
print(X_train.shape, X_test.shape)
(105, 4) (45, 4)

[9]:

# The training labels should have as many entries as X_train has rows.
print(y_train.shape)
(105,)

3.利用训练集训练决策树(criterion='entropy',即按信息增益选择划分,近似ID3;注意sklearn实际构建的是CART式二叉树),随机种子同上设为888

[12]:

# criterion='entropy' selects splits by information gain (ID3-style);
# fit() returns the estimator itself, so the call can be chained.
clf = DecisionTreeClassifier(
    criterion='entropy',
    random_state=888,
).fit(X_train, y_train)

4.测试模型的分类性能

[14]:

 
# Score on the held-out test split (mean accuracy for sklearn classifiers).
# As the last expression of the cell, Jupyter displays the value.
clf.score(X_test,y_test)

5.绘制决策树

[15]:

 
from sklearn import tree

[16]:

 
# Print a text rendering of the fitted tree. No feature_names are passed,
# so features are shown by column index (feature_0 ... feature_3).
print(tree.export_text(clf))
|--- feature_3 <= 0.75
|   |--- class: 0
|--- feature_3 >  0.75
|   |--- feature_2 <= 4.75
|   |   |--- class: 1
|   |--- feature_2 >  4.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- feature_1 <= 3.10
|   |   |   |   |   |--- class: 2
|   |   |   |   |--- feature_1 >  3.10
|   |   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.95
|   |   |   |--- class: 2

[17]:

 
import matplotlib.pyplot as plt

[64]:

# Create a large square canvas so the tree's node boxes stay readable.
fig, ax = plt.subplots(figsize=(10, 10))
# Draw on the axes created above explicitly rather than relying on
# matplotlib's implicit "current axes"; filled=True colours nodes by class.
# Kept as the last expression so Jupyter displays the returned annotations.
tree.plot_tree(clf, feature_names=iris.feature_names,
               class_names=iris.target_names, filled=True, ax=ax)

[64]:

[Text(0.3333333333333333, 0.9166666666666666, 'petal width (cm) <= 0.75\nentropy = 1.582\nsamples = 105\nvalue = [34, 33, 38]\nclass = virginica'),
 Text(0.16666666666666666, 0.75, 'entropy = 0.0\nsamples = 34\nvalue = [34, 0, 0]\nclass = setosa'),
 Text(0.5, 0.75, 'petal length (cm) <= 4.75\nentropy = 0.996\nsamples = 71\nvalue = [0, 33, 38]\nclass = virginica'),
 Text(0.3333333333333333, 0.5833333333333334, 'entropy = 0.0\nsamples = 30\nvalue = [0, 30, 0]\nclass = versicolor'),
 Text(0.6666666666666666, 0.5833333333333334, 'petal length (cm) <= 4.95\nentropy = 0.378\nsamples = 41\nvalue = [0, 3, 38]\nclass = virginica'),
 Text(0.5, 0.4166666666666667, 'petal width (cm) <= 1.65\nentropy = 0.954\nsamples = 8\nvalue = [0, 3, 5]\nclass = virginica'),
 Text(0.3333333333333333, 0.25, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]\nclass = versicolor'),
 Text(0.6666666666666666, 0.25, 'sepal width (cm) <= 3.1\nentropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]\nclass = virginica'),
 Text(0.5, 0.08333333333333333, 'entropy = 0.0\nsamples = 5\nvalue = [0, 0, 5]\nclass = virginica'),
 Text(0.8333333333333334, 0.08333333333333333, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = versicolor'),
 Text(0.8333333333333334, 0.4166666666666667, 'entropy = 0.0\nsamples = 33\nvalue = [0, 0, 33]\nclass = virginica')]

 完整代码

# Iris classification with an entropy-criterion decision tree:
# 7:3 train/test split, random seed 888, then text and graphical tree output.
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

SEED = 888  # shared by the data split and the tree for reproducibility

iris = load_iris()
# Bare expressions are discarded when run as a script, so print explicitly.
print(dir(iris))

x = iris.data
y = iris.target
print(x.shape)

X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=SEED
)
print(X_train.shape, X_test.shape)
print(y_train.shape)

# criterion='entropy' chooses splits by information gain (ID3-style), but
# sklearn still grows a binary, threshold-based (CART-style) tree.
clf = DecisionTreeClassifier(criterion="entropy", random_state=SEED)
clf = clf.fit(X_train, y_train)

# Test-set accuracy; must be printed or the result is lost in a script.
print(clf.score(X_test, y_test))

# Pass feature names so the text tree matches the labels in the plot below.
print(tree.export_text(clf, feature_names=iris.feature_names))

fig, ax = plt.subplots(figsize=(10, 10))
tree.plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    ax=ax,  # draw on the axes created above, not the implicit current axes
)
# Outside Jupyter the figure is never rendered without an explicit show().
plt.show()

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

张謹礧

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值