利用回归树DecisionTreeRegressor进行回归预测(复习10)

本文是个人学习笔记,内容主要基于回归树DecisionTreeRegressor对boston数据集学习回归模型和利用模型预测。

树模型可以解决非线性特征的问题,树模型不要求对特征标准化和统一量化(即数值型和类目型特征都可以直接被用到树模型的构建和预测过程),树模型可以直观地输出决策过程,使得预测结果具有可解释性。
使用树模型时要防止过拟合,对数据噪声的敏感度较高(预测稳定性较差),有训练数据构建最佳的树模型是NP难问题,因此实际操作时使用的类似贪婪算法的解法只能找到一些次优解。

回归树叶节点的数据类型是连续的,而分类树叶节点的数据类型是离散的。
回归树叶节点是一个个具体的值,而分类树叶节点是依据训练样本类别确定的预测类别。
回归树的叶节点返回的是“一团”训练数据的均值,而不是具体的、连续的预测值。

from sklearn.datasets import load_boston
boston=load_boston()
print(boston.DESCR)   #打印数据描述

这里写图片描述
这里写图片描述

print(boston.feature_names)

#Output:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO' 'B' 'LSTAT']
from sklearn.cross_validation import train_test_split
import numpy as np
X=boston.data
y=boston.target
X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=33,test_size=0.25)
print('The max target value is',np.max(boston.target))
print('The min target value is',np.min(boston.target))
print('The average target value is',np.mean(boston.target))

这里写图片描述

from sklearn.tree import DecisionTreeRegressor

dtr=DecisionTreeRegressor()
dtr.fit(X_train,y_train)
dtr_y_predict=dtr.predict(X_test)
from sklearn.metrics import r2_score,mean_absolute_error,mean_squared_error

print('R-squared value of DecisionTreeRegressor:',dtr.score(X_test,y_test))
print('The mean squared error of DecisionTreeRegressor:',mean_squared_error(y_test,dtr_y_predict))
print('The mean absolute error of DecisionTreeRegressor:',mean_absolute_error(y_test,dtr_y_predict))

这里写图片描述

import sys
import os
os.environ["PATH"] += os.pathsep + 'D:\PYTHON35\Anaconda3.4.2\Lib\site-packages\graphviz-2.38\bin'
#'D:\PYTHON35\Anaconda3.4.2\Lib\site-packages\graphviz-2.38\bin'是解压缩graphviz-2.38.zip包后bin文件夹所在位置

%matplotlib inline
import numpy as np
from IPython.display import Image  
from sklearn import tree
import pydotplus 
import graphviz
dot_data = tree.export_graphviz(dtr, out_file=None, 
                         feature_names=boston.feature_names,  
                         class_names=['0','1'], 
                         filled=True, rounded=True,  
                         special_characters=True)   #feature_names格式是np.array
graph = pydotplus.graph_from_dot_data(dot_data)  
Image(graph.create_png())

这里写图片描述

  • 4
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值