引入相关模块
from sklearn.tree import DecisionTreeClassifier#决策树模型
import matplotlib.pyplot as plt#用于显示决策树
from sklearn.tree import plot_tree#用于显示决策树
import pandas as pd#用于读写数据
加载数据
iris=pd.read_csv('iris.csv')#读取文件
X=iris[['sepal_length','sepal_width']]#提取自变量
y=iris['species']#提取因变量
决策树拟合
tree_clf=DecisionTreeClassifier(max_depth=2)#初始化决策树
tree_clf.fit(X, y)#拟合
可视化显示
%matplotlib inline
plt.figure()
plot_tree(tree_clf,
filled=True,
feature_names=['sepal_length','sepal_width'],
class_names=['stosa','versicolor','virginica']
)
模型应用
a=[[3.5,4],[6.3,3]]#自变量数据集
print(tree_clf.predict(a))#数据归类
print(tree_clf.predict_proba(a))#数据归类概率
#输出:
#['setosa' 'virginica']
#[[0.97777778 0.02222222 0. ]
# [0. 0.29090909 0.70909091]]