鸢尾花数据集,特征为连续值数据的决策树的多分类

1.导入工具

import pandas as pd
from sklearn import preprocessing
from sklearn import tree
from sklearn.datasets import load_iris

2.导入鸢尾花数据集,探索数据集
iris=load_iris()
#iris是一个字典,包含了数据、标签、标签名、数据描述等信息。可以通过键来索引对应值。
iris
#查看iris字典里的所有键
dir(iris)
iris.data
#150个数据,每个数据都有四个维度的特征,每个特征都是连续数值
iris.data.shape
#四个特征列名
iris.feature_names
#标签,0,1,2对应三种不同的鸢尾花
iris.target
#三种鸢尾花的名字
iris.target_names
鸢尾花数据集的描述说明信息
print(iris.DESCR)

3.构建决策树模型
dir(iris)
clf=tree.DecisionTreeClassifier(max_depth=4)
clf=clf.fit(iris.data, iris.target)
clf

4.可视化决策树
import pydotplus
from IPython.display import Image,display
dot_data=tree.export_graphviz(clf,
                             out_file=None,
                             feature_names=iris.feature_names,
                             class_names=iris.target_names,
                             filled=True,
                             rounded=True
                             )
graph=pydotplus.graph_from_dot_data(dot_data)
display(Image(graph.create_png()))


5.对整个训练集做预测
clf.predict(iris.data)

6.对单个样本做预测
#假设有一朵新的鸢尾花,四个特征分别为6.6cm,2.5cm,4.3cm,1,3cm。用训练好的决策树判断它属于哪一类鸢尾花。
import numpy as np
a1=np.array([6.6, 2.5, 4.3, 1.3])
a1
a1.shape
a1.reshape(1,-1).shape
clf.predict(a1.reshape(1,-1))
#属于第二类鸢尾花。
7.对多个样本做预测
a1=iris.data[30]
a2=iris.data[70]
a3=iris.data[120]
import numpy as np
b=np.row_stack((a1,a2,a3))
b
clf.predict(b)
import numpy as np
import matplotlib.pyplot as plt
%matplotlib.colors import ListedIormap
from matplotlib.colors import ListedColormap 
from sklearn import datasets
from sklearn import tree
iris=datasets.load_iris()
x=iris.data[:,2:4]#取出花瓣的长和宽
y=iris.target#取出标签
#计算散点图的上下界
x_min,x_max=x[:,0].min() -.5,  x[:,0].max()+.5
y_min,y_max=x[:,1].min() -.5,  x[:,1].max()+.5
#绘制边界
camo=cmap_light=ListedColormap(['#AAAAFF','#AAFFAA','#FFAAAA'])
h=.02
xx,yy=np.meshgrid(np.arange(x_min,x_max,h),np.arange(y_min,y_max,h))
clf=tree.DecisionTreeClassifier(max_depth=4)
clf=clf.fit(x, y)
Z=clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z=Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx,yy,Z,cmap=cmap_light)
plt.scatter(x[:,0],x[:,1],c=y) 
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.show()
 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值