机器学习 sklearn学习 第二天 回归树

from sklearn.datasets import load_boston  # 著名波士顿房价数据
from sklearn.model_selection import cross_val_score  # 交叉验证
from sklearn.tree import DecisionTreeRegressor
# todo:
# 几乎所有参数,属性和接口都和分类树一模一样
# 参数
# 参数1:criterion
# 输入 mse              使用均方误差
# 输入firedman_mse      费尔德曼均方误差
# 输入 mae              绝对平均误差

# todo:
#  交叉验证 :用于观察模型是否稳定,将数据划为n份,依次把1份作为测试集,n-1次作为训练集,得到n个评估指标,求平均
# boston = load_boston()
# refressor=DecisionTreeRegressor(random_state=0) # 实例化
# cross_val_score(refressor,boston.data,boston.targrt,cv=10 # 不需要划分测试集,训练集   cv告知划分多少份2,默认五份
#                 ,scoring="neg_mean_squared_error"  # 默认不填,默认返回r的平方,约接近1越好    填了就是负的均方误差  作为模型的评估标准
#                 ) # 进行交叉验证

# 用回归树拟合正弦曲线
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.RandomState(1)  # 生成随机数种子
##  rng.rand(10)           ---> 结果  生成十个数的一维数组   范围0-1之间的随机数
##  rng.rang(2,3)          ---> 结果  生成2行3列的二维数组   范围0-1之间的随机数
##  5 * rng.rand(80, 1)    ---> 结果  生成80行1列的二维数组  范围 0-5之间的随机数
X = np.sort(5 * rng.rand(80, 1), axis=0)  # 排序,作为横坐标
##  y必须是一维的,需要通过.ravel()来降维
y = np.sin(X).ravel()
# plt.figure()
#                     s 图像大小  边框颜色    点颜色   y轴名称
# plt.scatter(X,y,s=20,edgecolors="black",c="darkorange",label="data")
# plt.show()

# todo: 加上噪声 模型训练集真实状况

# print(y[::5])
### 所有行所有列的每5个数,步长
y[::5] += 3 * (0.5 - rng.rand(16))
# print(y[::5])
# plt.figure()
# plt.scatter(X,y,s=20,edgecolors="black",c="darkorange",label="data")
# plt.show()

### 了解降维函数ravel()的用法
# np.random.random((2,1))
# np.random.random((2,1)).ravel()
# np.random.random((2,1)).ravel().shape

# todo:开始训练模型
#
regr1 = DecisionTreeRegressor(max_depth=2)
regr2 = DecisionTreeRegressor(max_depth=5)
regr1.fit(X, y)
regr2.fit(X, y)
# 定义测试集
# np.arange(开始点,结束点,步长)
# [:,np.newaxis] 类切片,增维      必须是二维的特征矩阵
x_test = np.arange(0.0,5.0,0.01)[:,np.newaxis]  # 生成0到5之间步长为0.01的有序二维矩阵
y_1=regr1.predict(x_test) # 返回每个测试样本的分类/回归结果
y_2=regr2.predict(x_test)
plt.figure()
plt.scatter(X,y,s=20,edgecolors="black",c="darkorange",label="data")
plt.plot(x_test,y_1,color="cornflowerblue",label='max_depth=2',linewidth=2)
plt.plot(x_test,y_2,color="yellowgreen",label='max_depth=5',linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
发布了64 篇原创文章 · 获赞 13 · 访问量 2万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 编程工作室 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览