scikit-learn决策树学习

 决策树(DTS)用于分类和回归的非参数化监督学习方法,目标就是通过从数据特征中推断学习简单的决策规则建立模型来预测目标变量的值。

例如,下面例子,决策树利用一组if-then-else决策规则从数据中拟合近似正弦曲线。树越深,决策规则越复杂,模型越合适,太深也易出现过拟合问题

import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
rng = np.random.RandomState(1)
#生成80行1列二维数组,并在行上升序排序
X = np.sort(5*rng.rand(80,1),axis =0)
#np.sin(X)得到数组平铺
y = np.sin(X).ravel()
#生成16个不在正弦曲线上的噪声点
y[::5] += 3*(0.5-rng.rand(16))


#DecisionTreeRegressor(criterion='mse', max_depth=5, max_features=None,
#           max_leaf_nodes=None, min_impurity_split=1e-07,
#           min_samples_leaf=1, min_samples_split=2,
#          min_weight_fraction_leaf=0.0, presort=False, random_state=None,
#           splitter='best')
#设置树最大深度为2
regr1 = DecisionTreeRegressor(max_depth=2)
#设置树最大深度为5
regr2 = DecisionTreeRegressor(max_depth=5)
#训练模型
regr1.fit(X,y)
regr2.fit(X,y)

#生成500x1数组
X_test = np.arange(0.0,5.0,0.01)[:,np.newaxis]
#利用模型预测
y1 = regr1.predict(X_test)
y2 = regr2.predict(X_test)

#预测结果可视化输出
plt.figure()
#散点图
plt.scatter(X,y,c='darkorange',label='data')
plt.plot(X_test ,y1,c='b',label='max_depth=2',lw=2)
plt.plot(X_test ,y2,c='y',label='max_depth=5',lw=2)
#设置X轴标签
plt.xlabel('data')
#设置y轴标签
plt.ylabel('target')
#设置图标题
plt.title('Decision Tree Regression')
#添加图例
plt.legend()
plt.show()


从上图可知:树的最大深度设置太高,从训练集中学到更多细节数据,易出现过拟合(overfit)问题。

树可视化输出方法:

#决策树分类问题
from sklearn.datasets import load_iris
from sklearn import  tree
#导入iris数据集
iris = load_iris()
#DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
#            max_features=None, max_leaf_nodes=None,
#            min_impurity_split=1e-07, min_samples_leaf=1,
#            min_samples_split=2, min_weight_fraction_leaf=0.0,
#            presort=False, random_state=None, splitter='best')
clf = tree.DecisionTreeClassifier()
#训练模型
clf = clf.fit(iris.data , iris.target)
#将树以graphviz格式写入iris.dot,然后在cmd命令模式下,进入到iris.dot目录下,执行dot -Tpdf iris.dot -o iris.pdf创建iris.pdf文件
with open('iris.dot','w') as f:
    f = tree.export_graphviz(clf,out_file =f)

#import os 
#os.unlink('iris.dot')#删除当前目录下iris.dot文件

#安装pydotplus模块,利用pydotplus可以生产其他支持的文件类型
import pydotplus
dot_data = tree.export_graphviz(clf,out_file = None)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf('iris1.pdf')#True

#利用Image函数直接在notebook中显示
from IPython.display import Image
dot_data = tree.export_graphviz(clf,out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled =True,rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值